diff --git a/ecdsa.go b/ecdsa.go index f9773812..d1f1fe5f 100644 --- a/ecdsa.go +++ b/ecdsa.go @@ -53,8 +53,9 @@ func (m *SigningMethodECDSA) Alg() string { return m.Name } -// Implements the Verify method from SigningMethod -// For this verify method, key must be an ecdsa.PublicKey struct +// Implements the Verify method from SigningMethod. +// For this verify method, key must be in types of either *ecdsa.PublicKey or +// []*ecdsa.PublicKey (for rotation keys). func (m *SigningMethodECDSA) Verify(signingString, signature string, key interface{}) error { var err error @@ -64,15 +65,6 @@ func (m *SigningMethodECDSA) Verify(signingString, signature string, key interfa return err } - // Get the key - var ecdsaKey *ecdsa.PublicKey - switch k := key.(type) { - case *ecdsa.PublicKey: - ecdsaKey = k - default: - return ErrInvalidKeyType - } - if len(sig) != 2*m.KeySize { return ErrECDSAVerification } @@ -80,19 +72,38 @@ func (m *SigningMethodECDSA) Verify(signingString, signature string, key interfa r := big.NewInt(0).SetBytes(sig[:m.KeySize]) s := big.NewInt(0).SetBytes(sig[m.KeySize:]) - // Create hasher if !m.Hash.Available() { return ErrHashUnavailable } - hasher := m.Hash.New() - hasher.Write([]byte(signingString)) - // Verify the signature - if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus == true { - return nil - } else { + f := func(ecdsaKey *ecdsa.PublicKey) error { + // Create hasher + hasher := m.Hash.New() + hasher.Write([]byte(signingString)) + + // Verify the signature + if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus == true { + return nil + } return ErrECDSAVerification } + + // Get the key + switch v := key.(type) { + case *ecdsa.PublicKey: + return f(v) + case []*ecdsa.PublicKey: + var lastErr error + for _, ecdsaKey := range v { + lastErr = f(ecdsaKey) + if lastErr == nil { + return nil + } + } + return lastErr + default: + return ErrInvalidKeyType + } } // Implements the Sign method from SigningMethod diff --git a/hmac.go b/hmac.go index addbe5d4..8fed472b 100644 --- a/hmac.go +++ b/hmac.go @@ -47,12 +47,6 @@ func (m *SigningMethodHMAC) Alg() string { // Verify the signature of HSXXX tokens. Returns nil if the signature is valid. func (m *SigningMethodHMAC) Verify(signingString, signature string, key interface{}) error { - // Verify the key is the right type - keyBytes, ok := key.([]byte) - if !ok { - return ErrInvalidKeyType - } - // Decode signature, for comparison sig, err := DecodeSegment(signature) if err != nil { @@ -64,17 +58,37 @@ func (m *SigningMethodHMAC) Verify(signingString, signature string, key interfac return ErrHashUnavailable } - // This signing method is symmetric, so we validate the signature - // by reproducing the signature from the signing string and key, then - // comparing that against the provided signature. - hasher := hmac.New(m.Hash.New, keyBytes) - hasher.Write([]byte(signingString)) - if !hmac.Equal(sig, hasher.Sum(nil)) { - return ErrSignatureInvalid + // verifications to be done with each key. + f := func(keyBytes []byte) error { + // This signing method is symmetric, so we validate the signature + // by reproducing the signature from the signing string and key, then + // comparing that against the provided signature. + hasher := hmac.New(m.Hash.New, keyBytes) + hasher.Write([]byte(signingString)) + if !hmac.Equal(sig, hasher.Sum(nil)) { + return ErrSignatureInvalid + } + + // No validation errors. Signature is good. + return nil } - // No validation errors. Signature is good. - return nil + // Verify the key is the right type + switch v := key.(type) { + case []byte: + return f(v) + case [][]byte: + var lastErr error + for _, keyBytes := range v { + lastErr = f(keyBytes) + if lastErr == nil { + return nil + } + } + return lastErr + default: + return ErrInvalidKeyType + } } // Implements the Sign method from SigningMethod for this signing method. diff --git a/rsa.go b/rsa.go index e4caf1ca..0c2b9b9c 100644 --- a/rsa.go +++ b/rsa.go @@ -45,7 +45,8 @@ func (m *SigningMethodRSA) Alg() string { } // Implements the Verify method from SigningMethod -// For this signing method, must be an *rsa.PublicKey structure. +// For this signing method, key must be in types of either *rsa.PublicKey or +// []*rsa.PublicKey (for rotation keys). func (m *SigningMethodRSA) Verify(signingString, signature string, key interface{}) error { var err error @@ -55,22 +56,34 @@ func (m *SigningMethodRSA) Verify(signingString, signature string, key interface return err } - var rsaKey *rsa.PublicKey - var ok bool - - if rsaKey, ok = key.(*rsa.PublicKey); !ok { - return ErrInvalidKeyType - } - - // Create hasher if !m.Hash.Available() { return ErrHashUnavailable } - hasher := m.Hash.New() - hasher.Write([]byte(signingString)) - // Verify the signature - return rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig) + f := func(rsaKey *rsa.PublicKey) error { + // Create hasher + hasher := m.Hash.New() + hasher.Write([]byte(signingString)) + + // Verify the signature + return rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig) + } + + switch v := key.(type) { + case *rsa.PublicKey: + return f(v) + case []*rsa.PublicKey: + var lastErr error + for _, rsaKey := range v { + lastErr = f(rsaKey) + if lastErr == nil { + return nil + } + } + return lastErr + default: + return ErrInvalidKeyType + } } // Implements the Sign method from SigningMethod diff --git a/rsa_pss.go b/rsa_pss.go index c0147086..a89e221e 100644 --- a/rsa_pss.go +++ b/rsa_pss.go @@ -80,7 +80,8 @@ func init() { } // Implements the Verify method from SigningMethod -// For this verify method, key must be an rsa.PublicKey struct +// For this verify method, key must be in the types of either *rsa.PublicKey or +// []*rsa.PublicKey (for rotation keys). func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interface{}) error { var err error @@ -90,27 +91,38 @@ func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interf return err } - var rsaKey *rsa.PublicKey - switch k := key.(type) { - case *rsa.PublicKey: - rsaKey = k - default: - return ErrInvalidKey - } - - // Create hasher if !m.Hash.Available() { return ErrHashUnavailable } - hasher := m.Hash.New() - hasher.Write([]byte(signingString)) - opts := m.Options - if m.VerifyOptions != nil { - opts = m.VerifyOptions + f := func(rsaKey *rsa.PublicKey) error { + // Create hasher + hasher := m.Hash.New() + hasher.Write([]byte(signingString)) + + opts := m.Options + if m.VerifyOptions != nil { + opts = m.VerifyOptions + } + + return rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, opts) } - return rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, opts) + switch v := key.(type) { + case *rsa.PublicKey: + return f(v) + case []*rsa.PublicKey: + var lastErr error + for _, rsaKey := range v { + lastErr = f(rsaKey) + if lastErr == nil { + return nil + } + } + return lastErr + default: + return ErrInvalidKey + } } // Implements the Sign method from SigningMethod