diff --git a/cipher_suite.go b/cipher_suite.go index af95dec2e..1d4f57b67 100644 --- a/cipher_suite.go +++ b/cipher_suite.go @@ -4,6 +4,7 @@ package dtls import ( + "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" @@ -258,11 +259,16 @@ func filterCipherSuitesForCertificate(cert *tls.Certificate, cipherSuites []Ciph if cert == nil || cert.PrivateKey == nil { return cipherSuites } + signer, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return cipherSuites + } + var certType clientcertificate.Type - switch cert.PrivateKey.(type) { - case ed25519.PrivateKey, *ecdsa.PrivateKey: + switch signer.Public().(type) { + case ed25519.PublicKey, *ecdsa.PublicKey: certType = clientcertificate.ECDSASign - case *rsa.PrivateKey: + case *rsa.PublicKey: certType = clientcertificate.RSASign } diff --git a/config.go b/config.go index 54a86c0ee..161d742a5 100644 --- a/config.go +++ b/config.go @@ -4,6 +4,7 @@ package dtls import ( + "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" @@ -273,10 +274,14 @@ func validateConfig(config *Config) error { return errInvalidCertificate } if cert.PrivateKey != nil { - switch cert.PrivateKey.(type) { - case ed25519.PrivateKey: - case *ecdsa.PrivateKey: - case *rsa.PrivateKey: + signer, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return errInvalidPrivateKey + } + switch signer.Public().(type) { + case ed25519.PublicKey: + case *ecdsa.PublicKey: + case *rsa.PublicKey: default: return errInvalidPrivateKey } diff --git a/conn_test.go b/conn_test.go index 9edb3aee0..b7b5bca17 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3074,21 +3074,21 @@ func TestCipherSuiteMatchesCertificateType(t *testing.T) { }() var ( - priv crypto.PrivateKey - err error + signer crypto.Signer + err error ) if test.generateRSA { - if priv, err = rsa.GenerateKey(rand.Reader, 2048); err != nil { + if signer, err = rsa.GenerateKey(rand.Reader, 2048); err != nil { t.Fatal(err) } } else { - if priv, err = ecdsa.GenerateKey(cryptoElliptic.P256(), rand.Reader); err != nil { + if signer, err = ecdsa.GenerateKey(cryptoElliptic.P256(), rand.Reader); err != nil { t.Fatal(err) } } - serverCert, err := selfsign.SelfSign(priv) + serverCert, err := selfsign.SelfSign(signer) if err != nil { t.Fatal(err) } diff --git a/crypto.go b/crypto.go index 25b2a1f9f..2dc8b7278 100644 --- a/crypto.go +++ b/crypto.go @@ -43,18 +43,18 @@ func valueKeyMessage(clientRandom, serverRandom, publicKey []byte, namedCurve el // hash/signature algorithm pair that appears in that extension // // https://tools.ietf.org/html/rfc5246#section-7.4.2 -func generateKeySignature(clientRandom, serverRandom, publicKey []byte, namedCurve elliptic.Curve, privateKey crypto.PrivateKey, hashAlgorithm hash.Algorithm) ([]byte, error) { +func generateKeySignature(clientRandom, serverRandom, publicKey []byte, namedCurve elliptic.Curve, signer crypto.Signer, hashAlgorithm hash.Algorithm) ([]byte, error) { msg := valueKeyMessage(clientRandom, serverRandom, publicKey, namedCurve) - switch p := privateKey.(type) { - case ed25519.PrivateKey: + switch signer.Public().(type) { + case ed25519.PublicKey: // https://crypto.stackexchange.com/a/55483 - return p.Sign(rand.Reader, msg, crypto.Hash(0)) - case *ecdsa.PrivateKey: + return signer.Sign(rand.Reader, msg, crypto.Hash(0)) + case *ecdsa.PublicKey: hashed := hashAlgorithm.Digest(msg) - return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) - case *rsa.PrivateKey: + return signer.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) + case *rsa.PublicKey: hashed := hashAlgorithm.Digest(msg) - return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) + return signer.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) } return nil, errKeySignatureGenerateUnimplemented @@ -107,21 +107,21 @@ func verifyKeySignature(message, remoteKeySignature []byte, hashAlgorithm hash.A // CertificateVerify message is sent to explicitly verify possession of // the private key in the certificate. // https://tools.ietf.org/html/rfc5246#section-7.3 -func generateCertificateVerify(handshakeBodies []byte, privateKey crypto.PrivateKey, hashAlgorithm hash.Algorithm) ([]byte, error) { - if p, ok := privateKey.(ed25519.PrivateKey); ok { +func generateCertificateVerify(handshakeBodies []byte, signer crypto.Signer, hashAlgorithm hash.Algorithm) ([]byte, error) { + if _, ok := signer.Public().(ed25519.PublicKey); ok { // https://pkg.go.dev/crypto/ed25519#PrivateKey.Sign // Sign signs the given message with priv. Ed25519 performs two passes over // messages to be signed and therefore cannot handle pre-hashed messages. - return p.Sign(rand.Reader, handshakeBodies, crypto.Hash(0)) + return signer.Sign(rand.Reader, handshakeBodies, crypto.Hash(0)) } hashed := hashAlgorithm.Digest(handshakeBodies) - switch p := privateKey.(type) { - case *ecdsa.PrivateKey: - return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) - case *rsa.PrivateKey: - return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) + switch signer.Public().(type) { + case *ecdsa.PublicKey: + return signer.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) + case *rsa.PublicKey: + return signer.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) } return nil, errInvalidSignatureAlgorithm diff --git a/flight4handler.go b/flight4handler.go index 7e4ae12f1..f58e65242 100644 --- a/flight4handler.go +++ b/flight4handler.go @@ -5,6 +5,7 @@ package dtls import ( "context" + "crypto" "crypto/rand" "crypto/x509" @@ -331,13 +332,18 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha serverRandom := state.localRandom.MarshalFixed() clientRandom := state.remoteRandom.MarshalFixed() + signer, ok := certificate.PrivateKey.(crypto.Signer) + if !ok { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidPrivateKey + } + // Find compatible signature scheme - signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, certificate.PrivateKey) + signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, signer) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err } - signature, err := generateKeySignature(clientRandom[:], serverRandom[:], state.localKeypair.PublicKey, state.namedCurve, certificate.PrivateKey, signatureHashAlgo.Hash) + signature, err := generateKeySignature(clientRandom[:], serverRandom[:], state.localKeypair.PublicKey, state.namedCurve, signer, signatureHashAlgo.Hash) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } diff --git a/flight5handler.go b/flight5handler.go index 7e940cdc9..02887ed84 100644 --- a/flight5handler.go +++ b/flight5handler.go @@ -66,7 +66,7 @@ func flight5Parse(_ context.Context, c flightConn, state *State, cache *handshak } func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { //nolint:gocognit - var privateKey crypto.PrivateKey + var signer crypto.Signer var pkts []*packet if state.remoteRequestedCertificate { _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-2, state.cipherSuite, @@ -88,7 +88,10 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errNotAcceptableCertificateChain } if certificate.Certificate != nil { - privateKey = certificate.PrivateKey + signer, ok = certificate.PrivateKey.(crypto.Signer) + if !ok { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errInvalidPrivateKey + } } pkts = append(pkts, &packet{ @@ -180,7 +183,7 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han // If the client has sent a certificate with signing ability, a digitally-signed // CertificateVerify message is sent to explicitly verify possession of the // private key in the certificate. - if state.remoteRequestedCertificate && privateKey != nil { + if state.remoteRequestedCertificate && signer != nil { plainText := append(cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, @@ -194,12 +197,12 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han // Find compatible signature scheme - signatureHashAlgo, err := signaturehash.SelectSignatureScheme(state.remoteCertRequestAlgs, privateKey) + signatureHashAlgo, err := signaturehash.SelectSignatureScheme(state.remoteCertRequestAlgs, signer) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err } - certVerify, err := generateCertificateVerify(plainText, privateKey, signatureHashAlgo.Hash) + certVerify, err := generateCertificateVerify(plainText, signer, signatureHashAlgo.Hash) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } diff --git a/pkg/crypto/selfsign/selfsign.go b/pkg/crypto/selfsign/selfsign.go index 6ef016724..c8b365580 100644 --- a/pkg/crypto/selfsign/selfsign.go +++ b/pkg/crypto/selfsign/selfsign.go @@ -53,13 +53,18 @@ func WithDNS(key crypto.PrivateKey, cn string, sans ...string) (tls.Certificate, maxBigInt = new(big.Int) // Max random value, a 130-bits integer, i.e 2^130 - 1 ) - switch k := key.(type) { - case ed25519.PrivateKey: - pubKey = k.Public() - case *ecdsa.PrivateKey: - pubKey = k.Public() - case *rsa.PrivateKey: - pubKey = k.Public() + signer, ok := key.(crypto.Signer) + if !ok { + return tls.Certificate{}, errInvalidPrivateKey + } + + switch k := signer.Public().(type) { + case ed25519.PublicKey: + pubKey = k + case *ecdsa.PublicKey: + pubKey = k + case *rsa.PublicKey: + pubKey = k default: return tls.Certificate{}, errInvalidPrivateKey } @@ -76,7 +81,7 @@ func WithDNS(key crypto.PrivateKey, cn string, sans ...string) (tls.Certificate, names = append(names, sans...) keyUsage := x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign - if _, isRSA := key.(*rsa.PrivateKey); isRSA { + if _, isRSA := signer.Public().(*rsa.PublicKey); isRSA { keyUsage |= x509.KeyUsageKeyEncipherment } @@ -98,7 +103,7 @@ func WithDNS(key crypto.PrivateKey, cn string, sans ...string) (tls.Certificate, }, } - raw, err := x509.CreateCertificate(rand.Reader, &template, &template, pubKey, key) + raw, err := x509.CreateCertificate(rand.Reader, &template, &template, pubKey, signer) if err != nil { return tls.Certificate{}, err } @@ -110,7 +115,7 @@ func WithDNS(key crypto.PrivateKey, cn string, sans ...string) (tls.Certificate, return tls.Certificate{ Certificate: [][]byte{raw}, - PrivateKey: key, + PrivateKey: signer, Leaf: leaf, }, nil } diff --git a/pkg/crypto/signaturehash/errors.go b/pkg/crypto/signaturehash/errors.go index 4aeb3e40a..2e2b72bb8 100644 --- a/pkg/crypto/signaturehash/errors.go +++ b/pkg/crypto/signaturehash/errors.go @@ -9,4 +9,5 @@ var ( errNoAvailableSignatureSchemes = errors.New("connection can not be created, no SignatureScheme satisfy this Config") errInvalidSignatureAlgorithm = errors.New("invalid signature algorithm") errInvalidHashAlgorithm = errors.New("invalid hash algorithm") + errInvalidPrivateKey = errors.New("invalid private key type") ) diff --git a/pkg/crypto/signaturehash/signaturehash.go b/pkg/crypto/signaturehash/signaturehash.go index 38587768b..0b52f7699 100644 --- a/pkg/crypto/signaturehash/signaturehash.go +++ b/pkg/crypto/signaturehash/signaturehash.go @@ -40,8 +40,12 @@ func Algorithms() []Algorithm { // SelectSignatureScheme returns most preferred and compatible scheme. func SelectSignatureScheme(sigs []Algorithm, privateKey crypto.PrivateKey) (Algorithm, error) { + signer, ok := privateKey.(crypto.Signer) + if !ok { + return Algorithm{}, errInvalidPrivateKey + } for _, ss := range sigs { - if ss.isCompatible(privateKey) { + if ss.isCompatible(signer) { return ss, nil } } @@ -49,13 +53,13 @@ func SelectSignatureScheme(sigs []Algorithm, privateKey crypto.PrivateKey) (Algo } // isCompatible checks that given private key is compatible with the signature scheme. -func (a *Algorithm) isCompatible(privateKey crypto.PrivateKey) bool { - switch privateKey.(type) { - case ed25519.PrivateKey: +func (a *Algorithm) isCompatible(signer crypto.Signer) bool { + switch signer.Public().(type) { + case ed25519.PublicKey: return a.Signature == signature.Ed25519 - case *ecdsa.PrivateKey: + case *ecdsa.PublicKey: return a.Signature == signature.ECDSA - case *rsa.PrivateKey: + case *rsa.PublicKey: return a.Signature == signature.RSA default: return false