From 697b56a086c7e43be76460e75b62a26be372fedd Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Mon, 16 Jan 2023 00:35:23 +0000 Subject: [PATCH 01/34] Initial commit for TLS implementation. working on the server side. --- acceptor.go | 10 + connection.go | 46 + eventloop.go | 36 + go.mod | 1 + go.sum | 2 + options.go | 11 + pkg/tls/alert.go | 99 ++ pkg/tls/auth.go | 289 +++++ pkg/tls/buf.go | 209 ++++ pkg/tls/cipher_suites.go | 537 +++++++++ pkg/tls/common.go | 1450 +++++++++++++++++++++++ pkg/tls/conn.go | 1260 ++++++++++++++++++++ pkg/tls/handshake_client.go | 1034 +++++++++++++++++ pkg/tls/handshake_client_tls13.go | 675 +++++++++++ pkg/tls/handshake_messages.go | 1809 +++++++++++++++++++++++++++++ pkg/tls/handshake_server.go | 869 ++++++++++++++ pkg/tls/handshake_server_tls13.go | 869 ++++++++++++++ pkg/tls/key_agreement.go | 334 ++++++ pkg/tls/key_schedule.go | 199 ++++ pkg/tls/prf.go | 283 +++++ pkg/tls/ticket.go | 185 +++ pkg/tls/tls.go | 205 ++++ 22 files changed, 10412 insertions(+) create mode 100644 pkg/tls/alert.go create mode 100644 pkg/tls/auth.go create mode 100644 pkg/tls/buf.go create mode 100644 pkg/tls/cipher_suites.go create mode 100644 pkg/tls/common.go create mode 100644 pkg/tls/conn.go create mode 100644 pkg/tls/handshake_client.go create mode 100644 pkg/tls/handshake_client_tls13.go create mode 100644 pkg/tls/handshake_messages.go create mode 100644 pkg/tls/handshake_server.go create mode 100644 pkg/tls/handshake_server_tls13.go create mode 100644 pkg/tls/key_agreement.go create mode 100644 pkg/tls/key_schedule.go create mode 100644 pkg/tls/prf.go create mode 100644 pkg/tls/ticket.go create mode 100644 pkg/tls/tls.go diff --git a/acceptor.go b/acceptor.go index d4fa0bdf1..0229ee48d 100644 --- a/acceptor.go +++ b/acceptor.go @@ -50,6 +50,11 @@ func (eng *engine) accept(fd int, _ netpoll.IOEvent) error { el := eng.lb.next(remoteAddr) c := newTCPConn(nfd, el, sa, el.ln.addr, remoteAddr) + if el.engine.opts.TLSconfig != nil { + if err = c.UpgradeTLS(el.engine.opts.TLSconfig); err != nil { + return err + } + } err = el.poller.UrgentTrigger(el.register, c) if err != nil { @@ -84,6 +89,11 @@ func (el *eventloop) accept(fd int, ev netpoll.IOEvent) error { } c := newTCPConn(nfd, el, sa, el.ln.addr, remoteAddr) + if el.engine.opts.TLSconfig != nil { + if err = c.UpgradeTLS(el.engine.opts.TLSconfig); err != nil { + return err + } + } if err = el.poller.AddRead(c.pollAttachment); err != nil { return err } diff --git a/connection.go b/connection.go index ae76bd650..94d7e61ad 100644 --- a/connection.go +++ b/connection.go @@ -33,6 +33,7 @@ import ( "github.com/panjf2000/gnet/v2/pkg/buffer/elastic" gerrors "github.com/panjf2000/gnet/v2/pkg/errors" bsPool "github.com/panjf2000/gnet/v2/pkg/pool/byteslice" + "github.com/panjf2000/gnet/v2/pkg/tls" ) type conn struct { @@ -48,6 +49,7 @@ type conn struct { fd int // file descriptor isDatagram bool // UDP protocol opened bool // connection opened event fired + tlsconn *tls.Conn // tls connection } func newTCPConn(fd int, el *eventloop, sa unix.Sockaddr, localAddr, remoteAddr net.Addr) (c *conn) { @@ -129,6 +131,17 @@ func (c *conn) open(buf []byte) error { func (c *conn) write(data []byte) (n int, err error) { n = len(data) + + if c.tlsconn != nil { + // use tls to encrypt the data before sending it + c.tlsconn.Write(data) + // err = c.loop.poller.ModReadWrite(c.pollAttachment) + // n = 0 + // also working + err = c.loop.write(c) + return + } + // If there is pending data in outbound buffer, the current data ought to be appended to the outbound buffer // for maintaining the sequence of network packets. if !c.outboundBuffer.IsEmpty() { @@ -159,6 +172,18 @@ func (c *conn) writev(bs [][]byte) (n int, err error) { n += len(b) } + if c.tlsconn != nil { + for _, b := range bs { + // use tls to encrypt the data before sending it + c.tlsconn.Write(b) + } + // err = c.loop.poller.ModReadWrite(c.pollAttachment) + // n = 0 + // also working + err = c.loop.write(c) + return + } + // If there is pending data in outbound buffer, the current data ought to be appended to the outbound buffer // for maintaining the sequence of network packets. if !c.outboundBuffer.IsEmpty() { @@ -467,3 +492,24 @@ func (c *conn) Close() error { return }, nil) } + +func (c *conn) UpgradeTLS(config *tls.Config) (err error) { + c.tlsconn, err = tls.Server(c, &c.inboundBuffer, c.outboundBuffer, config.Clone()) + + //很有可能握手包在UpgradeTls之前发过来了,这里把inboundBuffer剩余数据当做握手数据处理 + if c.inboundBuffer.Len() > 0 { + c.tlsconn.RawWrite(c.inboundBuffer.Bytes()) + c.inboundBuffer.Reset() + if err := c.tlsconn.Handshake(); err != nil { + return err + } + } + + //握手失败的关了 + time.AfterFunc(time.Second*5, func() { + if c.opened && (c.tlsconn == nil || !c.tlsconn.HandshakeComplete()) { + c.Close() + } + }) + return err +} diff --git a/eventloop.go b/eventloop.go index 8bfad8017..7b0fff0bc 100644 --- a/eventloop.go +++ b/eventloop.go @@ -126,6 +126,42 @@ func (el *eventloop) read(c *conn) error { return el.closeConn(c, os.NewSyscallError("read", err)) } + if c.tlsconn != nil { + c.tlsconn.RawWrite(el.buffer[:n]) + if !c.tlsconn.HandshakeComplete() { + //先判断是否足够一条消息 + data := c.tlsconn.RawData() + if len(data) < 5 || len(data) < 5+int(data[3])<<8|int(data[4]) { + return nil + } + if err = c.tlsconn.Handshake(); err != nil { + return el.closeConn(c, os.NewSyscallError("TLS handshake", err)) + } + if !c.tlsconn.HandshakeComplete() || len(c.tlsconn.RawData()) == 0 { //握手没成功,或者握手成功,但是没有数据黏包了 + c.Flush() + return nil + } + } + + if err = c.tlsconn.ReadFrame(); err != nil { + return el.closeConn(c, os.NewSyscallError("TLS read", err)) + } + + if c.inboundBuffer.IsEmpty() { + return nil + } + + action := el.eventHandler.OnTraffic(c) + switch action { + case None: + case Close: + return el.closeConn(c, nil) + case Shutdown: + return gerrors.ErrEngineShutdown + } + return nil + } + c.buffer = el.buffer[:n] action := el.eventHandler.OnTraffic(c) switch action { diff --git a/go.mod b/go.mod index 1ef79846f..ac113fc99 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ require ( github.com/stretchr/testify v1.8.1 github.com/valyala/bytebufferpool v1.0.0 go.uber.org/zap v1.21.0 + golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 golang.org/x/sys v0.3.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0 ) diff --git a/go.sum b/go.sum index 0c47ca9cf..8a93adf87 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,8 @@ go.uber.org/zap v1.21.0 h1:WefMeulhovoZ2sYXz7st6K0sLj7bBhpiFaud4r4zST8= go.uber.org/zap v1.21.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 h1:pLI5jrR7OSLijeIDcmRxNmw2api+jEfxLoykJVice/E= +golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= diff --git a/options.go b/options.go index d941b23bb..1609aed33 100644 --- a/options.go +++ b/options.go @@ -18,6 +18,7 @@ import ( "time" "github.com/panjf2000/gnet/v2/pkg/logging" + "github.com/panjf2000/gnet/v2/pkg/tls" ) // Option is a function that will set up option. @@ -109,6 +110,9 @@ type Options struct { // SocketSendBuffer sets the maximum socket send buffer in bytes. SocketSendBuffer int + // TLSconfig sets the configuration of a TLS connection + TLSconfig *tls.Config + // LogPath the local path where logs will be written, this is the easiest way to set up logging, // gnet instantiates a default uber-go/zap logger with this given log path, you are also allowed to employ // you own logger during the lifetime by implementing the following log.Logger interface. @@ -249,3 +253,10 @@ func WithMulticastInterfaceIndex(idx int) Option { opts.MulticastInterfaceIndex = idx } } + +// WithTLS sets the tls configuration which includes the cert, the key, the cipher suite, the rotocol version, and etc. +func WithTLS(tlsconfig *tls.Config) Option { + return func(opts *Options) { + opts.TLSconfig = tlsconfig + } +} diff --git a/pkg/tls/alert.go b/pkg/tls/alert.go new file mode 100644 index 000000000..4790b7372 --- /dev/null +++ b/pkg/tls/alert.go @@ -0,0 +1,99 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import "strconv" + +type alert uint8 + +const ( + // alert level + alertLevelWarning = 1 + alertLevelError = 2 +) + +const ( + alertCloseNotify alert = 0 + alertUnexpectedMessage alert = 10 + alertBadRecordMAC alert = 20 + alertDecryptionFailed alert = 21 + alertRecordOverflow alert = 22 + alertDecompressionFailure alert = 30 + alertHandshakeFailure alert = 40 + alertBadCertificate alert = 42 + alertUnsupportedCertificate alert = 43 + alertCertificateRevoked alert = 44 + alertCertificateExpired alert = 45 + alertCertificateUnknown alert = 46 + alertIllegalParameter alert = 47 + alertUnknownCA alert = 48 + alertAccessDenied alert = 49 + alertDecodeError alert = 50 + alertDecryptError alert = 51 + alertExportRestriction alert = 60 + alertProtocolVersion alert = 70 + alertInsufficientSecurity alert = 71 + alertInternalError alert = 80 + alertInappropriateFallback alert = 86 + alertUserCanceled alert = 90 + alertNoRenegotiation alert = 100 + alertMissingExtension alert = 109 + alertUnsupportedExtension alert = 110 + alertCertificateUnobtainable alert = 111 + alertUnrecognizedName alert = 112 + alertBadCertificateStatusResponse alert = 113 + alertBadCertificateHashValue alert = 114 + alertUnknownPSKIdentity alert = 115 + alertCertificateRequired alert = 116 + alertNoApplicationProtocol alert = 120 +) + +var alertText = map[alert]string{ + alertCloseNotify: "close notify", + alertUnexpectedMessage: "unexpected message", + alertBadRecordMAC: "bad record MAC", + alertDecryptionFailed: "decryption failed", + alertRecordOverflow: "record overflow", + alertDecompressionFailure: "decompression failure", + alertHandshakeFailure: "handshake failure", + alertBadCertificate: "bad certificate", + alertUnsupportedCertificate: "unsupported certificate", + alertCertificateRevoked: "revoked certificate", + alertCertificateExpired: "expired certificate", + alertCertificateUnknown: "unknown certificate", + alertIllegalParameter: "illegal parameter", + alertUnknownCA: "unknown certificate authority", + alertAccessDenied: "access denied", + alertDecodeError: "error decoding message", + alertDecryptError: "error decrypting message", + alertExportRestriction: "export restriction", + alertProtocolVersion: "protocol version not supported", + alertInsufficientSecurity: "insufficient security level", + alertInternalError: "internal error", + alertInappropriateFallback: "inappropriate fallback", + alertUserCanceled: "user canceled", + alertNoRenegotiation: "no renegotiation", + alertMissingExtension: "missing extension", + alertUnsupportedExtension: "unsupported extension", + alertCertificateUnobtainable: "certificate unobtainable", + alertUnrecognizedName: "unrecognized name", + alertBadCertificateStatusResponse: "bad certificate status response", + alertBadCertificateHashValue: "bad certificate hash value", + alertUnknownPSKIdentity: "unknown PSK identity", + alertCertificateRequired: "certificate required", + alertNoApplicationProtocol: "no application protocol", +} + +func (e alert) String() string { + s, ok := alertText[e] + if ok { + return "tls: " + s + } + return "tls: alert(" + strconv.Itoa(int(e)) + ")" +} + +func (e alert) Error() string { + return e.String() +} diff --git a/pkg/tls/auth.go b/pkg/tls/auth.go new file mode 100644 index 000000000..ad5f9a2e4 --- /dev/null +++ b/pkg/tls/auth.go @@ -0,0 +1,289 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rsa" + "errors" + "fmt" + "hash" + "io" +) + +// verifyHandshakeSignature verifies a signature against pre-hashed +// (if required) handshake contents. +func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, signed, sig []byte) error { + switch sigType { + case signatureECDSA: + pubKey, ok := pubkey.(*ecdsa.PublicKey) + if !ok { + return fmt.Errorf("expected an ECDSA public key, got %T", pubkey) + } + if !ecdsa.VerifyASN1(pubKey, signed, sig) { + return errors.New("ECDSA verification failure") + } + case signatureEd25519: + pubKey, ok := pubkey.(ed25519.PublicKey) + if !ok { + return fmt.Errorf("expected an Ed25519 public key, got %T", pubkey) + } + if !ed25519.Verify(pubKey, signed, sig) { + return errors.New("Ed25519 verification failure") + } + case signaturePKCS1v15: + pubKey, ok := pubkey.(*rsa.PublicKey) + if !ok { + return fmt.Errorf("expected an RSA public key, got %T", pubkey) + } + if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, signed, sig); err != nil { + return err + } + case signatureRSAPSS: + pubKey, ok := pubkey.(*rsa.PublicKey) + if !ok { + return fmt.Errorf("expected an RSA public key, got %T", pubkey) + } + signOpts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash} + if err := rsa.VerifyPSS(pubKey, hashFunc, signed, sig, signOpts); err != nil { + return err + } + default: + return errors.New("internal error: unknown signature type") + } + return nil +} + +const ( + serverSignatureContext = "TLS 1.3, server CertificateVerify\x00" + clientSignatureContext = "TLS 1.3, client CertificateVerify\x00" +) + +var signaturePadding = []byte{ + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, +} + +// signedMessage returns the pre-hashed (if necessary) message to be signed by +// certificate keys in TLS 1.3. See RFC 8446, Section 4.4.3. +func signedMessage(sigHash crypto.Hash, context string, transcript hash.Hash) []byte { + if sigHash == directSigning { + b := &bytes.Buffer{} + b.Write(signaturePadding) + io.WriteString(b, context) + b.Write(transcript.Sum(nil)) + return b.Bytes() + } + h := sigHash.New() + h.Write(signaturePadding) + io.WriteString(h, context) + h.Write(transcript.Sum(nil)) + return h.Sum(nil) +} + +// typeAndHashFromSignatureScheme returns the corresponding signature type and +// crypto.Hash for a given TLS SignatureScheme. +func typeAndHashFromSignatureScheme(signatureAlgorithm SignatureScheme) (sigType uint8, hash crypto.Hash, err error) { + switch signatureAlgorithm { + case PKCS1WithSHA1, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512: + sigType = signaturePKCS1v15 + case PSSWithSHA256, PSSWithSHA384, PSSWithSHA512: + sigType = signatureRSAPSS + case ECDSAWithSHA1, ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512: + sigType = signatureECDSA + case Ed25519: + sigType = signatureEd25519 + default: + return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) + } + switch signatureAlgorithm { + case PKCS1WithSHA1, ECDSAWithSHA1: + hash = crypto.SHA1 + case PKCS1WithSHA256, PSSWithSHA256, ECDSAWithP256AndSHA256: + hash = crypto.SHA256 + case PKCS1WithSHA384, PSSWithSHA384, ECDSAWithP384AndSHA384: + hash = crypto.SHA384 + case PKCS1WithSHA512, PSSWithSHA512, ECDSAWithP521AndSHA512: + hash = crypto.SHA512 + case Ed25519: + hash = directSigning + default: + return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) + } + return sigType, hash, nil +} + +// legacyTypeAndHashFromPublicKey returns the fixed signature type and crypto.Hash for +// a given public key used with TLS 1.0 and 1.1, before the introduction of +// signature algorithm negotiation. +func legacyTypeAndHashFromPublicKey(pub crypto.PublicKey) (sigType uint8, hash crypto.Hash, err error) { + switch pub.(type) { + case *rsa.PublicKey: + return signaturePKCS1v15, crypto.MD5SHA1, nil + case *ecdsa.PublicKey: + return signatureECDSA, crypto.SHA1, nil + case ed25519.PublicKey: + // RFC 8422 specifies support for Ed25519 in TLS 1.0 and 1.1, + // but it requires holding on to a handshake transcript to do a + // full signature, and not even OpenSSL bothers with the + // complexity, so we can't even test it properly. + return 0, 0, fmt.Errorf("tls: Ed25519 public keys are not supported before TLS 1.2") + default: + return 0, 0, fmt.Errorf("tls: unsupported public key: %T", pub) + } +} + +var rsaSignatureSchemes = []struct { + scheme SignatureScheme + minModulusBytes int + maxVersion uint16 +}{ + // RSA-PSS is used with PSSSaltLengthEqualsHash, and requires + // emLen >= hLen + sLen + 2 + {PSSWithSHA256, crypto.SHA256.Size()*2 + 2, VersionTLS13}, + {PSSWithSHA384, crypto.SHA384.Size()*2 + 2, VersionTLS13}, + {PSSWithSHA512, crypto.SHA512.Size()*2 + 2, VersionTLS13}, + // PKCS#1 v1.5 uses prefixes from hashPrefixes in crypto/rsa, and requires + // emLen >= len(prefix) + hLen + 11 + // TLS 1.3 dropped support for PKCS#1 v1.5 in favor of RSA-PSS. + {PKCS1WithSHA256, 19 + crypto.SHA256.Size() + 11, VersionTLS12}, + {PKCS1WithSHA384, 19 + crypto.SHA384.Size() + 11, VersionTLS12}, + {PKCS1WithSHA512, 19 + crypto.SHA512.Size() + 11, VersionTLS12}, + {PKCS1WithSHA1, 15 + crypto.SHA1.Size() + 11, VersionTLS12}, +} + +// signatureSchemesForCertificate returns the list of supported SignatureSchemes +// for a given certificate, based on the public key and the protocol version, +// and optionally filtered by its explicit SupportedSignatureAlgorithms. +// +// This function must be kept in sync with supportedSignatureAlgorithms. +func signatureSchemesForCertificate(version uint16, cert *Certificate) []SignatureScheme { + priv, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return nil + } + + var sigAlgs []SignatureScheme + switch pub := priv.Public().(type) { + case *ecdsa.PublicKey: + if version != VersionTLS13 { + // In TLS 1.2 and earlier, ECDSA algorithms are not + // constrained to a single curve. + sigAlgs = []SignatureScheme{ + ECDSAWithP256AndSHA256, + ECDSAWithP384AndSHA384, + ECDSAWithP521AndSHA512, + ECDSAWithSHA1, + } + break + } + switch pub.Curve { + case elliptic.P256(): + sigAlgs = []SignatureScheme{ECDSAWithP256AndSHA256} + case elliptic.P384(): + sigAlgs = []SignatureScheme{ECDSAWithP384AndSHA384} + case elliptic.P521(): + sigAlgs = []SignatureScheme{ECDSAWithP521AndSHA512} + default: + return nil + } + case *rsa.PublicKey: + size := pub.Size() + sigAlgs = make([]SignatureScheme, 0, len(rsaSignatureSchemes)) + for _, candidate := range rsaSignatureSchemes { + if size >= candidate.minModulusBytes && version <= candidate.maxVersion { + sigAlgs = append(sigAlgs, candidate.scheme) + } + } + case ed25519.PublicKey: + sigAlgs = []SignatureScheme{Ed25519} + default: + return nil + } + + if cert.SupportedSignatureAlgorithms != nil { + var filteredSigAlgs []SignatureScheme + for _, sigAlg := range sigAlgs { + if isSupportedSignatureAlgorithm(sigAlg, cert.SupportedSignatureAlgorithms) { + filteredSigAlgs = append(filteredSigAlgs, sigAlg) + } + } + return filteredSigAlgs + } + return sigAlgs +} + +// selectSignatureScheme picks a SignatureScheme from the peer's preference list +// that works with the selected certificate. It's only called for protocol +// versions that support signature algorithms, so TLS 1.2 and 1.3. +func selectSignatureScheme(vers uint16, c *Certificate, peerAlgs []SignatureScheme) (SignatureScheme, error) { + supportedAlgs := signatureSchemesForCertificate(vers, c) + if len(supportedAlgs) == 0 { + return 0, unsupportedCertificateError(c) + } + if len(peerAlgs) == 0 && vers == VersionTLS12 { + // For TLS 1.2, if the client didn't send signature_algorithms then we + // can assume that it supports SHA1. See RFC 5246, Section 7.4.1.4.1. + peerAlgs = []SignatureScheme{PKCS1WithSHA1, ECDSAWithSHA1} + } + // Pick signature scheme in the peer's preference order, as our + // preference order is not configurable. + for _, preferredAlg := range peerAlgs { + if isSupportedSignatureAlgorithm(preferredAlg, supportedAlgs) { + return preferredAlg, nil + } + } + return 0, errors.New("tls: peer doesn't support any of the certificate's signature algorithms") +} + +// unsupportedCertificateError returns a helpful error for certificates with +// an unsupported private key. +func unsupportedCertificateError(cert *Certificate) error { + switch cert.PrivateKey.(type) { + case rsa.PrivateKey, ecdsa.PrivateKey: + return fmt.Errorf("tls: unsupported certificate: private key is %T, expected *%T", + cert.PrivateKey, cert.PrivateKey) + case *ed25519.PrivateKey: + return fmt.Errorf("tls: unsupported certificate: private key is *ed25519.PrivateKey, expected ed25519.PrivateKey") + } + + signer, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return fmt.Errorf("tls: certificate private key (%T) does not implement crypto.Signer", + cert.PrivateKey) + } + + switch pub := signer.Public().(type) { + case *ecdsa.PublicKey: + switch pub.Curve { + case elliptic.P256(): + case elliptic.P384(): + case elliptic.P521(): + default: + return fmt.Errorf("tls: unsupported certificate curve (%s)", pub.Curve.Params().Name) + } + case *rsa.PublicKey: + return fmt.Errorf("tls: certificate RSA key size too small for supported signature algorithms") + case ed25519.PublicKey: + default: + return fmt.Errorf("tls: unsupported certificate key (%T)", pub) + } + + if cert.SupportedSignatureAlgorithms != nil { + return fmt.Errorf("tls: peer doesn't support the certificate custom signature algorithms") + } + + return fmt.Errorf("tls: internal error: unsupported key (%T)", cert.PrivateKey) +} diff --git a/pkg/tls/buf.go b/pkg/tls/buf.go new file mode 100644 index 000000000..b9ca83d5b --- /dev/null +++ b/pkg/tls/buf.go @@ -0,0 +1,209 @@ +package tls + +import ( + "io" + "unsafe" +) + +type MsgBuffer struct { + b []byte + l int //长度 + i int //起点位置 +} + +const ( + blocksize = 1024 * 5 //清理失效数据阈值 + appendsize = 4096 +) + +func NewBuffer(n int) *MsgBuffer { + return &MsgBuffer{b: make([]byte, 0, n)} +} + +func (w *MsgBuffer) Reset() { + w.l = 0 + w.i = 0 +} + +func (w *MsgBuffer) Make(l int) []byte { + if w.i > blocksize { + copy(w.b[:w.l-w.i], w.b[w.i:w.l]) + w.l -= w.i + w.i = 0 + } + o := w.l + w.l += l + if len(w.b) < w.l { //扩容 + if cap(w.b) < w.l { + add := w.l - len(w.b) + if add > appendsize { + w.b = append(w.b, make([]byte, add)...) + } else { + w.b = append(w.b, make([]byte, appendsize)...) + } + } + w.b = w.b[:w.l] + } + return w.b[o:w.l] +} + +func (w *MsgBuffer) Write(b []byte) (int, error) { + if w.i > blocksize { + copy(w.b[:w.l-w.i], w.b[w.i:w.l]) + w.l -= w.i + w.i = 0 + } + l := len(b) + o := w.l + w.l += l + if len(w.b) < w.l { + if cap(w.b) < w.l { + add := w.l - len(w.b) + if add > appendsize { + w.b = append(w.b, make([]byte, add)...) + } else { + w.b = append(w.b, make([]byte, appendsize)...) + } + } + w.b = w.b[:w.l] + } + copy(w.b[o:w.l], b) + return l, nil +} + +func (w *MsgBuffer) WriteString(s string) { + if w.i > blocksize { + copy(w.b[:w.l-w.i], w.b[w.i:w.l]) + w.l -= w.i + w.i = 0 + } + x := (*[2]uintptr)(unsafe.Pointer(&s)) + h := [3]uintptr{x[0], x[1], x[1]} + b := *(*[]byte)(unsafe.Pointer(&h)) + l := len(b) + o := w.l + w.l += l + if len(w.b) < w.l { //扩容 + if cap(w.b) < w.l { + add := w.l - len(w.b) + if add > appendsize { + w.b = append(w.b, make([]byte, add)...) + } else { + w.b = append(w.b, make([]byte, appendsize)...) + } + } + w.b = w.b[:w.l] + } + copy(w.b[o:w.l], b) +} + +func (w *MsgBuffer) WriteByte(s byte) error { + if w.i > blocksize { + copy(w.b[:w.l-w.i], w.b[w.i:w.l]) + w.l -= w.i + w.i = 0 + } + w.l++ + if len(w.b) < w.l { + if cap(w.b) < w.l { + add := w.l - len(w.b) + if add > appendsize { + w.b = append(w.b, make([]byte, add)...) + } else { + w.b = append(w.b, make([]byte, appendsize)...) + } + } + w.b = w.b[:w.l] + } + w.b[w.l-1] = s + + return nil +} + +func (w *MsgBuffer) Bytes() []byte { + return w.b[w.i:w.l] +} + +func (w *MsgBuffer) PreBytes(n int) []byte { + end := w.i + n + if end > w.l { + end = w.l + } + return w.b[w.i:end] +} + +func (w *MsgBuffer) Len() int { + return w.l - w.i +} + +func (w *MsgBuffer) Next(l int) []byte { + o := w.i + w.i += l + if w.i > w.l { + w.i = w.l + } + return w.b[o:w.i] +} + +func (w *MsgBuffer) Truncate(i int) { + w.l = w.i + i +} + +func (w *MsgBuffer) String() string { + b := make([]byte, w.l-w.i) + copy(b, w.b[w.i:w.l]) + return *(*string)(unsafe.Pointer(&b)) +} + +// New returns a new MsgBuffer whose buffer has the given size. +func New(size int) *MsgBuffer { + + return &MsgBuffer{ + b: make([]byte, size), + } +} + +// Shift shifts the "read" pointer. +func (r *MsgBuffer) Shift(len int) { + if len <= 0 { + return + } + if len < r.Len() { + r.i += len + if r.i > r.l { + r.i = r.l + } + } else { + r.Reset() + } +} + +func (r *MsgBuffer) Close() error { + return nil +} + +func (r *MsgBuffer) Read(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + if r.i == r.l { + return 0, io.EOF + } + o := r.i + r.i += len(p) + if r.i > r.l { + r.i = r.l + } + copy(p, r.b[o:r.i]) + return r.i - o, nil +} + +// ReadByte reads and returns the next byte from the input or ErrIsEmpty. +func (r *MsgBuffer) ReadByte() (b byte, err error) { + if r.i == r.l { + return 0, io.EOF + } + b = r.b[r.i] + r.i++ + return b, err +} diff --git a/pkg/tls/cipher_suites.go b/pkg/tls/cipher_suites.go new file mode 100644 index 000000000..62ec2bdb0 --- /dev/null +++ b/pkg/tls/cipher_suites.go @@ -0,0 +1,537 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "crypto" + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/hmac" + "crypto/rc4" + "crypto/sha1" + "crypto/sha256" + "crypto/x509" + "fmt" + "hash" + + "golang.org/x/crypto/chacha20poly1305" +) + +// CipherSuite is a TLS cipher suite. Note that most functions in this package +// accept and expose cipher suite IDs instead of this type. +type CipherSuite struct { + ID uint16 + Name string + + // Supported versions is the list of TLS protocol versions that can + // negotiate this cipher suite. + SupportedVersions []uint16 + + // Insecure is true if the cipher suite has known security issues + // due to its primitives, design, or implementation. + Insecure bool +} + +var ( + supportedUpToTLS12 = []uint16{VersionTLS11, VersionTLS12} + supportedOnlyTLS12 = []uint16{VersionTLS12} + supportedOnlyTLS13 = []uint16{VersionTLS13} +) + +// CipherSuites returns a list of cipher suites currently implemented by this +// package, excluding those with security issues, which are returned by +// InsecureCipherSuites. +// +// The list is sorted by ID. Note that the default cipher suites selected by +// this package might depend on logic that can't be captured by a static list. +func CipherSuites() []*CipherSuite { + return []*CipherSuite{ + {TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, false}, + {TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, + {TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, + {TLS_RSA_WITH_AES_128_GCM_SHA256, "TLS_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, + {TLS_RSA_WITH_AES_256_GCM_SHA384, "TLS_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, + + {TLS_AES_128_GCM_SHA256, "TLS_AES_128_GCM_SHA256", supportedOnlyTLS13, false}, + {TLS_AES_256_GCM_SHA384, "TLS_AES_256_GCM_SHA384", supportedOnlyTLS13, false}, + {TLS_CHACHA20_POLY1305_SHA256, "TLS_CHACHA20_POLY1305_SHA256", supportedOnlyTLS13, false}, + + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, + {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false}, + } +} + +// InsecureCipherSuites returns a list of cipher suites currently implemented by +// this package and which have security issues. +// +// Most applications should not use the cipher suites in this list, and should +// only use those returned by CipherSuites. +func InsecureCipherSuites() []*CipherSuite { + // RC4 suites are broken because RC4 is. + // CBC-SHA256 suites have no Lucky13 countermeasures. + return []*CipherSuite{ + {TLS_RSA_WITH_RC4_128_SHA, "TLS_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, + {TLS_RSA_WITH_AES_128_CBC_SHA256, "TLS_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, + {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, + {TLS_ECDHE_RSA_WITH_RC4_128_SHA, "TLS_ECDHE_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, + } +} + +// CipherSuiteName returns the standard name for the passed cipher suite ID +// (e.g. "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"), or a fallback representation +// of the ID value if the cipher suite is not implemented by this package. +func CipherSuiteName(id uint16) string { + for _, c := range CipherSuites() { + if c.ID == id { + return c.Name + } + } + for _, c := range InsecureCipherSuites() { + if c.ID == id { + return c.Name + } + } + return fmt.Sprintf("0x%04X", id) +} + +// a keyAgreement implements the client and server side of a TLS key agreement +// protocol by generating and processing key exchange messages. +type keyAgreement interface { + // On the server side, the first two methods are called in order. + + // In the case that the key agreement protocol doesn't use a + // ServerKeyExchange message, generateServerKeyExchange can return nil, + // nil. + generateServerKeyExchange(*Config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error) + processClientKeyExchange(*Config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error) + + // On the client side, the next two methods are called in order. + + // This method may not be called if the server doesn't send a + // ServerKeyExchange message. + processServerKeyExchange(*Config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error + generateClientKeyExchange(*Config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) +} + +const ( + // suiteECDHE indicates that the cipher suite involves elliptic curve + // Diffie-Hellman. This means that it should only be selected when the + // client indicates that it supports ECC with a curve and point format + // that we're happy with. + suiteECDHE = 1 << iota + // suiteECSign indicates that the cipher suite involves an ECDSA or + // EdDSA signature and therefore may only be selected when the server's + // certificate is ECDSA or EdDSA. If this is not set then the cipher suite + // is RSA based. + suiteECSign + // suiteTLS12 indicates that the cipher suite should only be advertised + // and accepted when using TLS 1.2. + suiteTLS12 + // suiteSHA384 indicates that the cipher suite uses SHA384 as the + // handshake hash. + suiteSHA384 + // suiteDefaultOff indicates that this cipher suite is not included by + // default. + suiteDefaultOff +) + +// A cipherSuite is a specific combination of key agreement, cipher and MAC function. +type cipherSuite struct { + id uint16 + // the lengths, in bytes, of the key material needed for each component. + keyLen int + macLen int + ivLen int + ka func(version uint16) keyAgreement + // flags is a bitmask of the suite* values, above. + flags int + cipher func(key, iv []byte, isRead bool) interface{} + mac func(version uint16, macKey []byte) macFunction + aead func(key, fixedNonce []byte) aead +} + +var cipherSuites = []*cipherSuite{ + // Ciphersuite order is chosen so that ECDHE comes before plain RSA and + // AEADs are the top preference. + {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, + {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, + {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteDefaultOff, cipherAES, macSHA256, nil}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteDefaultOff, cipherAES, macSHA256, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, + {TLS_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, rsaKA, suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, rsaKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, rsaKA, suiteTLS12 | suiteDefaultOff, cipherAES, macSHA256, nil}, + {TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, + {TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil}, + {TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil}, + + // RC4-based cipher suites are disabled by default. + {TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, suiteDefaultOff, cipherRC4, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE | suiteDefaultOff, cipherRC4, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteDefaultOff, cipherRC4, macSHA1, nil}, +} + +// selectCipherSuite returns the first cipher suite from ids which is also in +// supportedIDs and passes the ok filter. +func selectCipherSuite(ids, supportedIDs []uint16, ok func(*cipherSuite) bool) *cipherSuite { + for _, id := range ids { + candidate := cipherSuiteByID(id) + if candidate == nil || !ok(candidate) { + continue + } + + for _, suppID := range supportedIDs { + if id == suppID { + return candidate + } + } + } + return nil +} + +// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash +// algorithm to be used with HKDF. See RFC 8446, Appendix B.4. +type cipherSuiteTLS13 struct { + id uint16 + keyLen int + aead func(key, fixedNonce []byte) aead + hash crypto.Hash +} + +var cipherSuitesTLS13 = []*cipherSuiteTLS13{ + {TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256}, + {TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256}, + {TLS_AES_256_GCM_SHA384, 32, aeadAESGCMTLS13, crypto.SHA384}, +} + +func cipherRC4(key, iv []byte, isRead bool) interface{} { + cipher, _ := rc4.NewCipher(key) + return cipher +} + +func cipher3DES(key, iv []byte, isRead bool) interface{} { + block, _ := des.NewTripleDESCipher(key) + if isRead { + return cipher.NewCBCDecrypter(block, iv) + } + return cipher.NewCBCEncrypter(block, iv) +} + +func cipherAES(key, iv []byte, isRead bool) interface{} { + block, _ := aes.NewCipher(key) + if isRead { + return cipher.NewCBCDecrypter(block, iv) + } + return cipher.NewCBCEncrypter(block, iv) +} + +// macSHA1 returns a macFunction for the given protocol version. +func macSHA1(version uint16, key []byte) macFunction { + return tls10MAC{h: hmac.New(newConstantTimeHash(sha1.New), key)} +} + +// macSHA256 returns a SHA-256 based MAC. These are only supported in TLS 1.2 +// so the given version is ignored. +func macSHA256(version uint16, key []byte) macFunction { + return tls10MAC{h: hmac.New(sha256.New, key)} +} + +type macFunction interface { + // Size returns the length of the MAC. + Size() int + // MAC appends the MAC of (seq, header, data) to out. The extra data is fed + // into the MAC after obtaining the result to normalize timing. The result + // is only valid until the next invocation of MAC as the buffer is reused. + MAC(seq, header, data, extra []byte) []byte +} + +type aead interface { + cipher.AEAD + + // explicitNonceLen returns the number of bytes of explicit nonce + // included in each record. This is eight for older AEADs and + // zero for modern ones. + explicitNonceLen() int +} + +const ( + aeadNonceLength = 12 + noncePrefixLength = 4 +) + +// prefixNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to +// each call. +type prefixNonceAEAD struct { + // nonce contains the fixed part of the nonce in the first four bytes. + nonce [aeadNonceLength]byte + aead cipher.AEAD +} + +func (f *prefixNonceAEAD) NonceSize() int { return aeadNonceLength - noncePrefixLength } +func (f *prefixNonceAEAD) Overhead() int { return f.aead.Overhead() } +func (f *prefixNonceAEAD) explicitNonceLen() int { return f.NonceSize() } + +func (f *prefixNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { + copy(f.nonce[4:], nonce) + return f.aead.Seal(out, f.nonce[:], plaintext, additionalData) +} + +func (f *prefixNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { + copy(f.nonce[4:], nonce) + return f.aead.Open(out, f.nonce[:], ciphertext, additionalData) +} + +// xoredNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce +// before each call. +type xorNonceAEAD struct { + nonceMask [aeadNonceLength]byte + aead cipher.AEAD +} + +func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number +func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() } +func (f *xorNonceAEAD) explicitNonceLen() int { return 0 } + +func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData) + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + + return result +} + +func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData) + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + + return result, err +} + +func aeadAESGCM(key, noncePrefix []byte) aead { + if len(noncePrefix) != noncePrefixLength { + panic("tls: internal error: wrong nonce length") + } + aes, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(aes) + if err != nil { + panic(err) + } + + ret := &prefixNonceAEAD{aead: aead} + copy(ret.nonce[:], noncePrefix) + return ret +} + +func aeadAESGCMTLS13(key, nonceMask []byte) aead { + if len(nonceMask) != aeadNonceLength { + panic("tls: internal error: wrong nonce length") + } + aes, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(aes) + if err != nil { + panic(err) + } + + ret := &xorNonceAEAD{aead: aead} + copy(ret.nonceMask[:], nonceMask) + return ret +} + +func aeadChaCha20Poly1305(key, nonceMask []byte) aead { + if len(nonceMask) != aeadNonceLength { + panic("tls: internal error: wrong nonce length") + } + aead, err := chacha20poly1305.New(key) + if err != nil { + panic(err) + } + + ret := &xorNonceAEAD{aead: aead} + copy(ret.nonceMask[:], nonceMask) + return ret +} + +type constantTimeHash interface { + hash.Hash + ConstantTimeSum(b []byte) []byte +} + +// cthWrapper wraps any hash.Hash that implements ConstantTimeSum, and replaces +// with that all calls to Sum. It's used to obtain a ConstantTimeSum-based HMAC. +type cthWrapper struct { + h constantTimeHash +} + +func (c *cthWrapper) Size() int { return c.h.Size() } +func (c *cthWrapper) BlockSize() int { return c.h.BlockSize() } +func (c *cthWrapper) Reset() { c.h.Reset() } +func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) } +func (c *cthWrapper) Sum(b []byte) []byte { return c.h.ConstantTimeSum(b) } + +func newConstantTimeHash(h func() hash.Hash) func() hash.Hash { + return func() hash.Hash { + return &cthWrapper{h().(constantTimeHash)} + } +} + +// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3. +type tls10MAC struct { + h hash.Hash + buf []byte +} + +func (s tls10MAC) Size() int { + return s.h.Size() +} + +// MAC is guaranteed to take constant time, as long as +// len(seq)+len(header)+len(data)+len(extra) is constant. extra is not fed into +// the MAC, but is only provided to make the timing profile constant. +func (s tls10MAC) MAC(seq, header, data, extra []byte) []byte { + s.h.Reset() + s.h.Write(seq) + s.h.Write(header) + s.h.Write(data) + res := s.h.Sum(s.buf[:0]) + if extra != nil { + s.h.Write(extra) + } + return res +} + +func rsaKA(version uint16) keyAgreement { + return rsaKeyAgreement{} +} + +func ecdheECDSAKA(version uint16) keyAgreement { + return &ecdheKeyAgreement{ + isRSA: false, + version: version, + } +} + +func ecdheRSAKA(version uint16) keyAgreement { + return &ecdheKeyAgreement{ + isRSA: true, + version: version, + } +} + +// mutualCipherSuite returns a cipherSuite given a list of supported +// ciphersuites and the id requested by the peer. +func mutualCipherSuite(have []uint16, want uint16) *cipherSuite { + for _, id := range have { + if id == want { + return cipherSuiteByID(id) + } + } + return nil +} + +func cipherSuiteByID(id uint16) *cipherSuite { + for _, cipherSuite := range cipherSuites { + if cipherSuite.id == id { + return cipherSuite + } + } + return nil +} + +func mutualCipherSuiteTLS13(have []uint16, want uint16) *cipherSuiteTLS13 { + for _, id := range have { + if id == want { + return cipherSuiteTLS13ByID(id) + } + } + return nil +} + +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 { + for _, cipherSuite := range cipherSuitesTLS13 { + if cipherSuite.id == id { + return cipherSuite + } + } + return nil +} + +// A list of cipher suite IDs that are, or have been, implemented by this +// package. +// +// See https://www.iana.org/assignments/tls-parameters/tls-parameters.xml +const ( + // TLS 1.0 - 1.2 cipher suites. + TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 + TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000a + TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f + TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 + TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003c + TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009c + TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009d + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xc007 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xc009 + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xc00a + TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xc011 + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xc012 + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xc013 + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xc014 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc023 + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc027 + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02f + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02b + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc030 + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc02c + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca8 + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca9 + + // TLS 1.3 cipher suites. + TLS_AES_128_GCM_SHA256 uint16 = 0x1301 + TLS_AES_256_GCM_SHA384 uint16 = 0x1302 + TLS_CHACHA20_POLY1305_SHA256 uint16 = 0x1303 + + // TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator + // that the client is doing version fallback. See RFC 7507. + TLS_FALLBACK_SCSV uint16 = 0x5600 + + // Legacy names for the corresponding cipher suites with the correct _SHA256 + // suffix, retained for backward compatibility. + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 +) diff --git a/pkg/tls/common.go b/pkg/tls/common.go new file mode 100644 index 000000000..cf9e2235d --- /dev/null +++ b/pkg/tls/common.go @@ -0,0 +1,1450 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "container/list" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/sha512" + "crypto/x509" + "errors" + "fmt" + "io" + "strings" + "sync" + "time" + + "golang.org/x/sys/cpu" +) + +const ( + VersionTLS10 = 0x0301 + VersionTLS11 = 0x0302 + VersionTLS12 = 0x0303 + VersionTLS13 = 0x0304 + + // Deprecated: SSLv3 is cryptographically broken, and is no longer + // supported by this package. See golang.org/issue/32716. + VersionSSL30 = 0x0300 +) + +const ( + maxPlaintext = 16384 // maximum plaintext payload length + maxCiphertext = 16384 + 2048 // maximum ciphertext payload length + maxCiphertextTLS13 = 16384 + 256 // maximum ciphertext length in TLS 1.3 + recordHeaderLen = 5 // record header length + maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) + maxUselessRecords = 16 // maximum number of consecutive non-advancing records +) + +// TLS record types. +type recordType uint8 + +const ( + recordTypeChangeCipherSpec recordType = 20 + recordTypeAlert recordType = 21 + recordTypeHandshake recordType = 22 + recordTypeApplicationData recordType = 23 +) + +// TLS handshake message types. +const ( + typeHelloRequest uint8 = 0 + typeClientHello uint8 = 1 + typeServerHello uint8 = 2 + typeNewSessionTicket uint8 = 4 + typeEndOfEarlyData uint8 = 5 + typeEncryptedExtensions uint8 = 8 + typeCertificate uint8 = 11 + typeServerKeyExchange uint8 = 12 + typeCertificateRequest uint8 = 13 + typeServerHelloDone uint8 = 14 + typeCertificateVerify uint8 = 15 + typeClientKeyExchange uint8 = 16 + typeFinished uint8 = 20 + typeCertificateStatus uint8 = 22 + typeKeyUpdate uint8 = 24 + typeNextProtocol uint8 = 67 // Not IANA assigned + typeMessageHash uint8 = 254 // synthetic message +) + +// TLS compression types. +const ( + compressionNone uint8 = 0 +) + +// TLS extension numbers +const ( + extensionServerName uint16 = 0 + extensionStatusRequest uint16 = 5 + extensionSupportedCurves uint16 = 10 // supported_groups in TLS 1.3, see RFC 8446, Section 4.2.7 + extensionSupportedPoints uint16 = 11 + extensionSignatureAlgorithms uint16 = 13 + extensionALPN uint16 = 16 + extensionSCT uint16 = 18 + extensionSessionTicket uint16 = 35 + extensionPreSharedKey uint16 = 41 + extensionEarlyData uint16 = 42 + extensionSupportedVersions uint16 = 43 + extensionCookie uint16 = 44 + extensionPSKModes uint16 = 45 + extensionCertificateAuthorities uint16 = 47 + extensionSignatureAlgorithmsCert uint16 = 50 + extensionKeyShare uint16 = 51 + extensionRenegotiationInfo uint16 = 0xff01 +) + +// TLS signaling cipher suite values +const ( + scsvRenegotiation uint16 = 0x00ff +) + +// CurveID is the type of a TLS identifier for an elliptic curve. See +// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8. +// +// In TLS 1.3, this type is called NamedGroup, but at this time this library +// only supports Elliptic Curve based groups. See RFC 8446, Section 4.2.7. +type CurveID uint16 + +const ( + CurveP256 CurveID = 23 + CurveP384 CurveID = 24 + CurveP521 CurveID = 25 + X25519 CurveID = 29 +) + +// TLS 1.3 Key Share. See RFC 8446, Section 4.2.8. +type keyShare struct { + group CurveID + data []byte +} + +// TLS 1.3 PSK Key Exchange Modes. See RFC 8446, Section 4.2.9. +const ( + pskModePlain uint8 = 0 + pskModeDHE uint8 = 1 +) + +// TLS 1.3 PSK Identity. Can be a Session Ticket, or a reference to a saved +// session. See RFC 8446, Section 4.2.11. +type pskIdentity struct { + label []byte + obfuscatedTicketAge uint32 +} + +// TLS Elliptic Curve Point Formats +// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9 +const ( + pointFormatUncompressed uint8 = 0 +) + +// TLS CertificateStatusType (RFC 3546) +const ( + statusTypeOCSP uint8 = 1 +) + +// Certificate types (for certificateRequestMsg) +const ( + certTypeRSASign = 1 + certTypeECDSASign = 64 // ECDSA or EdDSA keys, see RFC 8422, Section 3. +) + +// Signature algorithms (for internal signaling use). Starting at 225 to avoid overlap with +// TLS 1.2 codepoints (RFC 5246, Appendix A.4.1), with which these have nothing to do. +const ( + signaturePKCS1v15 uint8 = iota + 225 + signatureRSAPSS + signatureECDSA + signatureEd25519 +) + +// directSigning is a standard Hash value that signals that no pre-hashing +// should be performed, and that the input should be signed directly. It is the +// hash function associated with the Ed25519 signature scheme. +var directSigning crypto.Hash = 0 + +// supportedSignatureAlgorithms contains the signature and hash algorithms that +// the code advertises as supported in a TLS 1.2+ ClientHello and in a TLS 1.2+ +// CertificateRequest. The two fields are merged to match with TLS 1.3. +// Note that in TLS 1.2, the ECDSA algorithms are not constrained to P-256, etc. +var supportedSignatureAlgorithms = []SignatureScheme{ + PSSWithSHA256, + ECDSAWithP256AndSHA256, + Ed25519, + PSSWithSHA384, + PSSWithSHA512, + PKCS1WithSHA256, + PKCS1WithSHA384, + PKCS1WithSHA512, + ECDSAWithP384AndSHA384, + ECDSAWithP521AndSHA512, + PKCS1WithSHA1, + ECDSAWithSHA1, +} + +// helloRetryRequestRandom is set as the Random value of a ServerHello +// to signal that the message is actually a HelloRetryRequest. +var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3. + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, + 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, + 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, + 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, +} + +const ( + // downgradeCanaryTLS12 or downgradeCanaryTLS11 is embedded in the server + // random as a downgrade protection if the server would be capable of + // negotiating a higher version. See RFC 8446, Section 4.1.3. + downgradeCanaryTLS12 = "DOWNGRD\x01" + downgradeCanaryTLS11 = "DOWNGRD\x00" +) + +// testingOnlyForceDowngradeCanary is set in tests to force the server side to +// include downgrade canaries even if it's using its highers supported version. +var testingOnlyForceDowngradeCanary bool + +// ConnectionState records basic TLS details about the connection. +type ConnectionState struct { + Version uint16 // TLS version used by the connection (e.g. VersionTLS12) + HandshakeComplete bool // TLS handshake is complete + DidResume bool // connection resumes a previous TLS connection + CipherSuite uint16 // cipher suite in use (TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, ...) + NegotiatedProtocol string // negotiated next protocol (not guaranteed to be from Config.NextProtos) + NegotiatedProtocolIsMutual bool // negotiated protocol was advertised by server (client side only) + ServerName string // server name requested by client, if any + PeerCertificates []*x509.Certificate // certificate chain presented by remote peer + VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates + SignedCertificateTimestamps [][]byte // SCTs from the peer, if any + OCSPResponse []byte // stapled OCSP response from peer, if any + + // ekm is a closure exposed via ExportKeyingMaterial. + ekm func(label string, context []byte, length int) ([]byte, error) + + // TLSUnique contains the "tls-unique" channel binding value (see RFC + // 5929, section 3). For resumed sessions this value will be nil + // because resumption does not include enough context (see + // https://mitls.org/pages/attacks/3SHAKE#channelbindings). This will + // change in future versions of Go once the TLS master-secret fix has + // been standardized and implemented. It is not defined in TLS 1.3. + TLSUnique []byte +} + +// ExportKeyingMaterial returns length bytes of exported key material in a new +// slice as defined in RFC 5705. If context is nil, it is not used as part of +// the seed. If the connection was set to allow renegotiation via +// Config.Renegotiation, this function will return an error. +func (cs *ConnectionState) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) { + return cs.ekm(label, context, length) +} + +// ClientAuthType declares the policy the server will follow for +// TLS Client Authentication. +type ClientAuthType int + +const ( + NoClientCert ClientAuthType = iota + RequestClientCert + RequireAnyClientCert + VerifyClientCertIfGiven + RequireAndVerifyClientCert +) + +// requiresClientCert reports whether the ClientAuthType requires a client +// certificate to be provided. +func requiresClientCert(c ClientAuthType) bool { + switch c { + case RequireAnyClientCert, RequireAndVerifyClientCert: + return true + default: + return false + } +} + +// ClientSessionState contains the state needed by clients to resume TLS +// sessions. +type ClientSessionState struct { + sessionTicket []uint8 // Encrypted ticket used for session resumption with server + vers uint16 // TLS version negotiated for the session + cipherSuite uint16 // Ciphersuite negotiated for the session + masterSecret []byte // Full handshake MasterSecret, or TLS 1.3 resumption_master_secret + serverCertificates []*x509.Certificate // Certificate chain presented by the server + verifiedChains [][]*x509.Certificate // Certificate chains we built for verification + receivedAt time.Time // When the session ticket was received from the server + ocspResponse []byte // Stapled OCSP response presented by the server + scts [][]byte // SCTs presented by the server + + // TLS 1.3 fields. + nonce []byte // Ticket nonce sent by the server, to derive PSK + useBy time.Time // Expiration of the ticket lifetime as set by the server + ageAdd uint32 // Random obfuscation factor for sending the ticket age +} + +// ClientSessionCache is a cache of ClientSessionState objects that can be used +// by a client to resume a TLS session with a given server. ClientSessionCache +// implementations should expect to be called concurrently from different +// goroutines. Up to TLS 1.2, only ticket-based resumption is supported, not +// SessionID-based resumption. In TLS 1.3 they were merged into PSK modes, which +// are supported via this interface. +type ClientSessionCache interface { + // Get searches for a ClientSessionState associated with the given key. + // On return, ok is true if one was found. + Get(sessionKey string) (session *ClientSessionState, ok bool) + + // Put adds the ClientSessionState to the cache with the given key. It might + // get called multiple times in a connection if a TLS 1.3 server provides + // more than one session ticket. If called with a nil *ClientSessionState, + // it should remove the cache entry. + Put(sessionKey string, cs *ClientSessionState) +} + +//go:generate stringer -type=SignatureScheme,CurveID,ClientAuthType -output=common_string.go + +// SignatureScheme identifies a signature algorithm supported by TLS. See +// RFC 8446, Section 4.2.3. +type SignatureScheme uint16 + +const ( + // RSASSA-PKCS1-v1_5 algorithms. + PKCS1WithSHA256 SignatureScheme = 0x0401 + PKCS1WithSHA384 SignatureScheme = 0x0501 + PKCS1WithSHA512 SignatureScheme = 0x0601 + + // RSASSA-PSS algorithms with public key OID rsaEncryption. + PSSWithSHA256 SignatureScheme = 0x0804 + PSSWithSHA384 SignatureScheme = 0x0805 + PSSWithSHA512 SignatureScheme = 0x0806 + + // ECDSA algorithms. Only constrained to a specific curve in TLS 1.3. + ECDSAWithP256AndSHA256 SignatureScheme = 0x0403 + ECDSAWithP384AndSHA384 SignatureScheme = 0x0503 + ECDSAWithP521AndSHA512 SignatureScheme = 0x0603 + + // EdDSA algorithms. + Ed25519 SignatureScheme = 0x0807 + + // Legacy signature and hash algorithms for TLS 1.2. + PKCS1WithSHA1 SignatureScheme = 0x0201 + ECDSAWithSHA1 SignatureScheme = 0x0203 +) + +// ClientHelloInfo contains information from a ClientHello message in order to +// guide application logic in the GetCertificate and GetConfigForClient callbacks. +type ClientHelloInfo struct { + // CipherSuites lists the CipherSuites supported by the client (e.g. + // TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256). + CipherSuites []uint16 + + // ServerName indicates the name of the server requested by the client + // in order to support virtual hosting. ServerName is only set if the + // client is using SNI (see RFC 4366, Section 3.1). + ServerName string + + // SupportedCurves lists the elliptic curves supported by the client. + // SupportedCurves is set only if the Supported Elliptic Curves + // Extension is being used (see RFC 4492, Section 5.1.1). + SupportedCurves []CurveID + + // SupportedPoints lists the point formats supported by the client. + // SupportedPoints is set only if the Supported Point Formats Extension + // is being used (see RFC 4492, Section 5.1.2). + SupportedPoints []uint8 + + // SignatureSchemes lists the signature and hash schemes that the client + // is willing to verify. SignatureSchemes is set only if the Signature + // Algorithms Extension is being used (see RFC 5246, Section 7.4.1.4.1). + SignatureSchemes []SignatureScheme + + // SupportedProtos lists the application protocols supported by the client. + // SupportedProtos is set only if the Application-Layer Protocol + // Negotiation Extension is being used (see RFC 7301, Section 3.1). + // + // Servers can select a protocol by setting Config.NextProtos in a + // GetConfigForClient return value. + SupportedProtos []string + + // SupportedVersions lists the TLS versions supported by the client. + // For TLS versions less than 1.3, this is extrapolated from the max + // version advertised by the client, so values other than the greatest + // might be rejected if used. + SupportedVersions []uint16 + + // Conn is the underlying net.Conn for the connection. Do not read + // from, or write to, this connection; that will cause the TLS + // connection to fail. + Conn conn + + // config is embedded by the GetCertificate or GetConfigForClient caller, + // for use with SupportsCertificate. + config *Config +} + +// CertificateRequestInfo contains information from a server's +// CertificateRequest message, which is used to demand a certificate and proof +// of control from a client. +type CertificateRequestInfo struct { + // AcceptableCAs contains zero or more, DER-encoded, X.501 + // Distinguished Names. These are the names of root or intermediate CAs + // that the server wishes the returned certificate to be signed by. An + // empty slice indicates that the server has no preference. + AcceptableCAs [][]byte + + // SignatureSchemes lists the signature schemes that the server is + // willing to verify. + SignatureSchemes []SignatureScheme + + // Version is the TLS version that was negotiated for this connection. + Version uint16 +} + +// RenegotiationSupport enumerates the different levels of support for TLS +// renegotiation. TLS renegotiation is the act of performing subsequent +// handshakes on a connection after the first. This significantly complicates +// the state machine and has been the source of numerous, subtle security +// issues. Initiating a renegotiation is not supported, but support for +// accepting renegotiation requests may be enabled. +// +// Even when enabled, the server may not change its identity between handshakes +// (i.e. the leaf certificate must be the same). Additionally, concurrent +// handshake and application data flow is not permitted so renegotiation can +// only be used with protocols that synchronise with the renegotiation, such as +// HTTPS. +// +// Renegotiation is not defined in TLS 1.3. +type RenegotiationSupport int + +const ( + // RenegotiateNever disables renegotiation. + RenegotiateNever RenegotiationSupport = iota + + // RenegotiateOnceAsClient allows a remote server to request + // renegotiation once per connection. + RenegotiateOnceAsClient + + // RenegotiateFreelyAsClient allows a remote server to repeatedly + // request renegotiation. + RenegotiateFreelyAsClient +) + +// A Config structure is used to configure a TLS client or server. +// After one has been passed to a TLS function it must not be +// modified. A Config may be reused; the tls package will also not +// modify it. +type Config struct { + // Rand provides the source of entropy for nonces and RSA blinding. + // If Rand is nil, TLS uses the cryptographic random reader in package + // crypto/rand. + // The Reader must be safe for use by multiple goroutines. + Rand io.Reader + + // Time returns the current time as the number of seconds since the epoch. + // If Time is nil, TLS uses time.Now. + Time func() time.Time + + // Certificates contains one or more certificate chains to present to the + // other side of the connection. The first certificate compatible with the + // peer's requirements is selected automatically. + // + // Server configurations must set one of Certificates, GetCertificate or + // GetConfigForClient. Clients doing client-authentication may set either + // Certificates or GetClientCertificate. + // + // Note: if there are multiple Certificates, and they don't have the + // optional field Leaf set, certificate selection will incur a significant + // per-handshake performance cost. + Certificates []Certificate + + // NameToCertificate maps from a certificate name to an element of + // Certificates. Note that a certificate name can be of the form + // '*.example.com' and so doesn't have to be a domain name as such. + // + // Deprecated: NameToCertificate only allows associating a single + // certificate with a given name. Leave this field nil to let the library + // select the first compatible chain from Certificates. + NameToCertificate map[string]*Certificate + + // GetCertificate returns a Certificate based on the given + // ClientHelloInfo. It will only be called if the client supplies SNI + // information or if Certificates is empty. + // + // If GetCertificate is nil or returns nil, then the certificate is + // retrieved from NameToCertificate. If NameToCertificate is nil, the + // best element of Certificates will be used. + GetCertificate func(*ClientHelloInfo) (*Certificate, error) + + // GetClientCertificate, if not nil, is called when a server requests a + // certificate from a client. If set, the contents of Certificates will + // be ignored. + // + // If GetClientCertificate returns an error, the handshake will be + // aborted and that error will be returned. Otherwise + // GetClientCertificate must return a non-nil Certificate. If + // Certificate.Certificate is empty then no certificate will be sent to + // the server. If this is unacceptable to the server then it may abort + // the handshake. + // + // GetClientCertificate may be called multiple times for the same + // connection if renegotiation occurs or if TLS 1.3 is in use. + GetClientCertificate func(*CertificateRequestInfo) (*Certificate, error) + + // GetConfigForClient, if not nil, is called after a ClientHello is + // received from a client. It may return a non-nil Config in order to + // change the Config that will be used to handle this connection. If + // the returned Config is nil, the original Config will be used. The + // Config returned by this callback may not be subsequently modified. + // + // If GetConfigForClient is nil, the Config passed to Server() will be + // used for all connections. + // + // If SessionTicketKey was explicitly set on the returned Config, or if + // SetSessionTicketKeys was called on the returned Config, those keys will + // be used. Otherwise, the original Config keys will be used (and possibly + // rotated if they are automatically managed). + GetConfigForClient func(*ClientHelloInfo) (*Config, error) + + // VerifyPeerCertificate, if not nil, is called after normal + // certificate verification by either a TLS client or server. It + // receives the raw ASN.1 certificates provided by the peer and also + // any verified chains that normal processing found. If it returns a + // non-nil error, the handshake is aborted and that error results. + // + // If normal verification fails then the handshake will abort before + // considering this callback. If normal verification is disabled by + // setting InsecureSkipVerify, or (for a server) when ClientAuth is + // RequestClientCert or RequireAnyClientCert, then this callback will + // be considered but the verifiedChains argument will always be nil. + VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + // VerifyConnection, if not nil, is called after normal certificate + // verification and after VerifyPeerCertificate by either a TLS client + // or server. If it returns a non-nil error, the handshake is aborted + // and that error results. + // + // If normal verification fails then the handshake will abort before + // considering this callback. This callback will run for all connections + // regardless of InsecureSkipVerify or ClientAuth settings. + VerifyConnection func(ConnectionState) error + + // RootCAs defines the set of root certificate authorities + // that clients use when verifying server certificates. + // If RootCAs is nil, TLS uses the host's root CA set. + RootCAs *x509.CertPool + + // NextProtos is a list of supported application level protocols, in + // order of preference. + NextProtos []string + + // ServerName is used to verify the hostname on the returned + // certificates unless InsecureSkipVerify is given. It is also included + // in the client's handshake to support virtual hosting unless it is + // an IP address. + ServerName string + + // ClientAuth determines the server's policy for + // TLS Client Authentication. The default is NoClientCert. + ClientAuth ClientAuthType + + // ClientCAs defines the set of root certificate authorities + // that servers use if required to verify a client certificate + // by the policy in ClientAuth. + ClientCAs *x509.CertPool + + // InsecureSkipVerify controls whether a client verifies the + // server's certificate chain and host name. + // If InsecureSkipVerify is true, TLS accepts any certificate + // presented by the server and any host name in that certificate. + // In this mode, TLS is susceptible to man-in-the-middle attacks. + // This should be used only for testing. + InsecureSkipVerify bool + + // CipherSuites is a list of supported cipher suites for TLS versions up to + // TLS 1.2. If CipherSuites is nil, a default list of secure cipher suites + // is used, with a preference order based on hardware performance. The + // default cipher suites might change over Go versions. Note that TLS 1.3 + // ciphersuites are not configurable. + CipherSuites []uint16 + + // PreferServerCipherSuites controls whether the server selects the + // client's most preferred ciphersuite, or the server's most preferred + // ciphersuite. If true then the server's preference, as expressed in + // the order of elements in CipherSuites, is used. + PreferServerCipherSuites bool + + // SessionTicketsDisabled may be set to true to disable session ticket and + // PSK (resumption) support. Note that on clients, session ticket support is + // also disabled if ClientSessionCache is nil. + SessionTicketsDisabled bool + + // SessionTicketKey is used by TLS servers to provide session resumption. + // See RFC 5077 and the PSK mode of RFC 8446. If zero, it will be filled + // with random data before the first server handshake. + // + // Deprecated: if this field is left at zero, session ticket keys will be + // automatically rotated every day and dropped after seven days. For + // customizing the rotation schedule or synchronizing servers that are + // terminating connections for the same host, use SetSessionTicketKeys. + SessionTicketKey [32]byte + + // ClientSessionCache is a cache of ClientSessionState entries for TLS + // session resumption. It is only used by clients. + ClientSessionCache ClientSessionCache + + // MinVersion contains the minimum TLS version that is acceptable. + // If zero, TLS 1.0 is currently taken as the minimum. + MinVersion uint16 + + // MaxVersion contains the maximum TLS version that is acceptable. + // If zero, the maximum version supported by this package is used, + // which is currently TLS 1.3. + MaxVersion uint16 + + // CurvePreferences contains the elliptic curves that will be used in + // an ECDHE handshake, in preference order. If empty, the default will + // be used. The client will use the first preference as the type for + // its key share in TLS 1.3. This may change in the future. + CurvePreferences []CurveID + + // DynamicRecordSizingDisabled disables adaptive sizing of TLS records. + // When true, the largest possible TLS record size is always used. When + // false, the size of TLS records may be adjusted in an attempt to + // improve latency. + DynamicRecordSizingDisabled bool + + // Renegotiation controls what types of renegotiation are supported. + // The default, none, is correct for the vast majority of applications. + Renegotiation RenegotiationSupport + + // KeyLogWriter optionally specifies a destination for TLS master secrets + // in NSS key log format that can be used to allow external programs + // such as Wireshark to decrypt TLS connections. + // See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format. + // Use of KeyLogWriter compromises security and should only be + // used for debugging. + KeyLogWriter io.Writer + + // mutex protects sessionTicketKeys and autoSessionTicketKeys. + mutex sync.RWMutex + // sessionTicketKeys contains zero or more ticket keys. If set, it means the + // the keys were set with SessionTicketKey or SetSessionTicketKeys. The + // first key is used for new tickets and any subsequent keys can be used to + // decrypt old tickets. The slice contents are not protected by the mutex + // and are immutable. + sessionTicketKeys []ticketKey + // autoSessionTicketKeys is like sessionTicketKeys but is owned by the + // auto-rotation logic. See Config.ticketKeys. + autoSessionTicketKeys []ticketKey +} + +const ( + // ticketKeyNameLen is the number of bytes of identifier that is prepended to + // an encrypted session ticket in order to identify the key used to encrypt it. + ticketKeyNameLen = 16 + + // ticketKeyLifetime is how long a ticket key remains valid and can be used to + // resume a client connection. + ticketKeyLifetime = 7 * 24 * time.Hour // 7 days + + // ticketKeyRotation is how often the server should rotate the session ticket key + // that is used for new tickets. + ticketKeyRotation = 24 * time.Hour +) + +// ticketKey is the internal representation of a session ticket key. +type ticketKey struct { + // keyName is an opaque byte string that serves to identify the session + // ticket key. It's exposed as plaintext in every session ticket. + keyName [ticketKeyNameLen]byte + aesKey [16]byte + hmacKey [16]byte + // created is the time at which this ticket key was created. See Config.ticketKeys. + created time.Time +} + +// ticketKeyFromBytes converts from the external representation of a session +// ticket key to a ticketKey. Externally, session ticket keys are 32 random +// bytes and this function expands that into sufficient name and key material. +func (c *Config) ticketKeyFromBytes(b [32]byte) (key ticketKey) { + hashed := sha512.Sum512(b[:]) + copy(key.keyName[:], hashed[:ticketKeyNameLen]) + copy(key.aesKey[:], hashed[ticketKeyNameLen:ticketKeyNameLen+16]) + copy(key.hmacKey[:], hashed[ticketKeyNameLen+16:ticketKeyNameLen+32]) + key.created = c.time() + return key +} + +// maxSessionTicketLifetime is the maximum allowed lifetime of a TLS 1.3 session +// ticket, and the lifetime we set for tickets we send. +const maxSessionTicketLifetime = 7 * 24 * time.Hour + +// Clone returns a shallow clone of c. It is safe to clone a Config that is +// being used concurrently by a TLS client or server. +func (c *Config) Clone() *Config { + c.mutex.RLock() + defer c.mutex.RUnlock() + + return &Config{ + Rand: c.Rand, + Time: c.Time, + Certificates: c.Certificates, + NameToCertificate: c.NameToCertificate, + GetCertificate: c.GetCertificate, + GetClientCertificate: c.GetClientCertificate, + GetConfigForClient: c.GetConfigForClient, + VerifyPeerCertificate: c.VerifyPeerCertificate, + VerifyConnection: c.VerifyConnection, + RootCAs: c.RootCAs, + NextProtos: c.NextProtos, + ServerName: c.ServerName, + ClientAuth: c.ClientAuth, + ClientCAs: c.ClientCAs, + InsecureSkipVerify: c.InsecureSkipVerify, + CipherSuites: c.CipherSuites, + PreferServerCipherSuites: c.PreferServerCipherSuites, + SessionTicketsDisabled: c.SessionTicketsDisabled, + SessionTicketKey: c.SessionTicketKey, + ClientSessionCache: c.ClientSessionCache, + MinVersion: c.MinVersion, + MaxVersion: c.MaxVersion, + CurvePreferences: c.CurvePreferences, + DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, + Renegotiation: c.Renegotiation, + KeyLogWriter: c.KeyLogWriter, + sessionTicketKeys: c.sessionTicketKeys, + autoSessionTicketKeys: c.autoSessionTicketKeys, + } +} + +// deprecatedSessionTicketKey is set as the prefix of SessionTicketKey if it was +// randomized for backwards compatibility but is not in use. +var deprecatedSessionTicketKey = []byte("DEPRECATED") + +// initLegacySessionTicketKeyRLocked ensures the legacy SessionTicketKey field is +// randomized if empty, and that sessionTicketKeys is populated from it otherwise. +func (c *Config) initLegacySessionTicketKeyRLocked() { + // Don't write if SessionTicketKey is already defined as our deprecated string, + // or if it is defined by the user but sessionTicketKeys is already set. + if c.SessionTicketKey != [32]byte{} && + (bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) || len(c.sessionTicketKeys) > 0) { + return + } + + // We need to write some data, so get an exclusive lock and re-check any conditions. + c.mutex.RUnlock() + defer c.mutex.RLock() + c.mutex.Lock() + defer c.mutex.Unlock() + if c.SessionTicketKey == [32]byte{} { + if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil { + panic(fmt.Sprintf("tls: unable to generate random session ticket key: %v", err)) + } + // Write the deprecated prefix at the beginning so we know we created + // it. This key with the DEPRECATED prefix isn't used as an actual + // session ticket key, and is only randomized in case the application + // reuses it for some reason. + copy(c.SessionTicketKey[:], deprecatedSessionTicketKey) + } else if !bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) && len(c.sessionTicketKeys) == 0 { + c.sessionTicketKeys = []ticketKey{c.ticketKeyFromBytes(c.SessionTicketKey)} + } + +} + +// ticketKeys returns the ticketKeys for this connection. +// If configForClient has explicitly set keys, those will +// be returned. Otherwise, the keys on c will be used and +// may be rotated if auto-managed. +// During rotation, any expired session ticket keys are deleted from +// c.sessionTicketKeys. If the session ticket key that is currently +// encrypting tickets (ie. the first ticketKey in c.sessionTicketKeys) +// is not fresh, then a new session ticket key will be +// created and prepended to c.sessionTicketKeys. +func (c *Config) ticketKeys(configForClient *Config) []ticketKey { + // If the ConfigForClient callback returned a Config with explicitly set + // keys, use those, otherwise just use the original Config. + if configForClient != nil { + configForClient.mutex.RLock() + if configForClient.SessionTicketsDisabled { + return nil + } + configForClient.initLegacySessionTicketKeyRLocked() + if len(configForClient.sessionTicketKeys) != 0 { + ret := configForClient.sessionTicketKeys + configForClient.mutex.RUnlock() + return ret + } + configForClient.mutex.RUnlock() + } + + c.mutex.RLock() + defer c.mutex.RUnlock() + if c.SessionTicketsDisabled { + return nil + } + c.initLegacySessionTicketKeyRLocked() + if len(c.sessionTicketKeys) != 0 { + return c.sessionTicketKeys + } + // Fast path for the common case where the key is fresh enough. + if len(c.autoSessionTicketKeys) > 0 && c.time().Sub(c.autoSessionTicketKeys[0].created) < ticketKeyRotation { + return c.autoSessionTicketKeys + } + + // autoSessionTicketKeys are managed by auto-rotation. + c.mutex.RUnlock() + defer c.mutex.RLock() + c.mutex.Lock() + defer c.mutex.Unlock() + // Re-check the condition in case it changed since obtaining the new lock. + if len(c.autoSessionTicketKeys) == 0 || c.time().Sub(c.autoSessionTicketKeys[0].created) >= ticketKeyRotation { + var newKey [32]byte + if _, err := io.ReadFull(c.rand(), newKey[:]); err != nil { + panic(fmt.Sprintf("unable to generate random session ticket key: %v", err)) + } + valid := make([]ticketKey, 0, len(c.autoSessionTicketKeys)+1) + valid = append(valid, c.ticketKeyFromBytes(newKey)) + for _, k := range c.autoSessionTicketKeys { + // While rotating the current key, also remove any expired ones. + if c.time().Sub(k.created) < ticketKeyLifetime { + valid = append(valid, k) + } + } + c.autoSessionTicketKeys = valid + } + return c.autoSessionTicketKeys +} + +// SetSessionTicketKeys updates the session ticket keys for a server. +// +// The first key will be used when creating new tickets, while all keys can be +// used for decrypting tickets. It is safe to call this function while the +// server is running in order to rotate the session ticket keys. The function +// will panic if keys is empty. +// +// Calling this function will turn off automatic session ticket key rotation. +// +// If multiple servers are terminating connections for the same host they should +// all have the same session ticket keys. If the session ticket keys leaks, +// previously recorded and future TLS connections using those keys might be +// compromised. +func (c *Config) SetSessionTicketKeys(keys [][32]byte) { + if len(keys) == 0 { + panic("tls: keys must have at least one key") + } + + newKeys := make([]ticketKey, len(keys)) + for i, bytes := range keys { + newKeys[i] = c.ticketKeyFromBytes(bytes) + } + + c.mutex.Lock() + c.sessionTicketKeys = newKeys + c.mutex.Unlock() +} + +func (c *Config) rand() io.Reader { + r := c.Rand + if r == nil { + return rand.Reader + } + return r +} + +func (c *Config) time() time.Time { + t := c.Time + if t == nil { + t = time.Now + } + return t() +} + +func (c *Config) cipherSuites() []uint16 { + s := c.CipherSuites + if s == nil { + s = defaultCipherSuites() + } + return s +} + +var supportedVersions = []uint16{ + VersionTLS13, + VersionTLS12, + VersionTLS11, + //VersionTLS10, +} + +func (c *Config) supportedVersions() []uint16 { + versions := make([]uint16, 0, len(supportedVersions)) + for _, v := range supportedVersions { + if c != nil && c.MinVersion != 0 && v < c.MinVersion { + continue + } + if c != nil && c.MaxVersion != 0 && v > c.MaxVersion { + continue + } + versions = append(versions, v) + } + return versions +} + +func (c *Config) maxSupportedVersion() uint16 { + supportedVersions := c.supportedVersions() + if len(supportedVersions) == 0 { + return 0 + } + return supportedVersions[0] +} + +// supportedVersionsFromMax returns a list of supported versions derived from a +// legacy maximum version value. Note that only versions supported by this +// library are returned. Any newer peer will use supportedVersions anyway. +func supportedVersionsFromMax(maxVersion uint16) []uint16 { + versions := make([]uint16, 0, len(supportedVersions)) + for _, v := range supportedVersions { + if v > maxVersion { + continue + } + versions = append(versions, v) + } + return versions +} + +var defaultCurvePreferences = []CurveID{X25519, CurveP256, CurveP384, CurveP521} + +func (c *Config) curvePreferences() []CurveID { + if c == nil || len(c.CurvePreferences) == 0 { + return defaultCurvePreferences + } + return c.CurvePreferences +} + +func (c *Config) supportsCurve(curve CurveID) bool { + for _, cc := range c.curvePreferences() { + if cc == curve { + return true + } + } + return false +} + +// mutualVersion returns the protocol version to use given the advertised +// versions of the peer. Priority is given to the peer preference order. +func (c *Config) mutualVersion(peerVersions []uint16) (uint16, bool) { + supportedVersions := c.supportedVersions() + for _, peerVersion := range peerVersions { + for _, v := range supportedVersions { + if v == peerVersion { + return v, true + } + } + } + return 0, false +} + +var errNoCertificates = errors.New("tls: no certificates configured") + +// getCertificate returns the best certificate for the given ClientHelloInfo, +// defaulting to the first element of c.Certificates. +func (c *Config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) { + if c.GetCertificate != nil && + (len(c.Certificates) == 0 || len(clientHello.ServerName) > 0) { + cert, err := c.GetCertificate(clientHello) + if cert != nil || err != nil { + return cert, err + } + } + + if len(c.Certificates) == 0 { + return nil, errNoCertificates + } + + if len(c.Certificates) == 1 { + // There's only one choice, so no point doing any work. + return &c.Certificates[0], nil + } + + if c.NameToCertificate != nil { + name := strings.ToLower(clientHello.ServerName) + if cert, ok := c.NameToCertificate[name]; ok { + return cert, nil + } + if len(name) > 0 { + labels := strings.Split(name, ".") + labels[0] = "*" + wildcardName := strings.Join(labels, ".") + if cert, ok := c.NameToCertificate[wildcardName]; ok { + return cert, nil + } + } + } + + for _, cert := range c.Certificates { + if err := clientHello.SupportsCertificate(&cert); err == nil { + return &cert, nil + } + } + + // If nothing matches, return the first certificate. + return &c.Certificates[0], nil +} + +// SupportsCertificate returns nil if the provided certificate is supported by +// the client that sent the ClientHello. Otherwise, it returns an error +// describing the reason for the incompatibility. +// +// If this ClientHelloInfo was passed to a GetConfigForClient or GetCertificate +// callback, this method will take into account the associated Config. Note that +// if GetConfigForClient returns a different Config, the change can't be +// accounted for by this method. +// +// This function will call x509.ParseCertificate unless c.Leaf is set, which can +// incur a significant performance cost. +func (chi *ClientHelloInfo) SupportsCertificate(c *Certificate) error { + // Note we don't currently support certificate_authorities nor + // signature_algorithms_cert, and don't check the algorithms of the + // signatures on the chain (which anyway are a SHOULD, see RFC 8446, + // Section 4.4.2.2). + + config := chi.config + if config == nil { + config = &Config{} + } + vers, ok := config.mutualVersion(chi.SupportedVersions) + if !ok { + return errors.New("no mutually supported protocol versions") + } + + // If the client specified the name they are trying to connect to, the + // certificate needs to be valid for it. + if chi.ServerName != "" { + x509Cert, err := c.leaf() + if err != nil { + return fmt.Errorf("failed to parse certificate: %w", err) + } + if err := x509Cert.VerifyHostname(chi.ServerName); err != nil { + return fmt.Errorf("certificate is not valid for requested server name: %w", err) + } + } + + // supportsRSAFallback returns nil if the certificate and connection support + // the static RSA key exchange, and unsupported otherwise. The logic for + // supporting static RSA is completely disjoint from the logic for + // supporting signed key exchanges, so we just check it as a fallback. + supportsRSAFallback := func(unsupported error) error { + // TLS 1.3 dropped support for the static RSA key exchange. + if vers == VersionTLS13 { + return unsupported + } + // The static RSA key exchange works by decrypting a challenge with the + // RSA private key, not by signing, so check the PrivateKey implements + // crypto.Decrypter, like *rsa.PrivateKey does. + if priv, ok := c.PrivateKey.(crypto.Decrypter); ok { + if _, ok := priv.Public().(*rsa.PublicKey); !ok { + return unsupported + } + } else { + return unsupported + } + // Finally, there needs to be a mutual cipher suite that uses the static + // RSA key exchange instead of ECDHE. + rsaCipherSuite := selectCipherSuite(chi.CipherSuites, config.cipherSuites(), func(c *cipherSuite) bool { + if c.flags&suiteECDHE != 0 { + return false + } + if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { + return false + } + return true + }) + if rsaCipherSuite == nil { + return unsupported + } + return nil + } + + // If the client sent the signature_algorithms extension, ensure it supports + // schemes we can use with this certificate and TLS version. + if len(chi.SignatureSchemes) > 0 { + if _, err := selectSignatureScheme(vers, c, chi.SignatureSchemes); err != nil { + return supportsRSAFallback(err) + } + } + + // In TLS 1.3 we are done because supported_groups is only relevant to the + // ECDHE computation, point format negotiation is removed, cipher suites are + // only relevant to the AEAD choice, and static RSA does not exist. + if vers == VersionTLS13 { + return nil + } + + // The only signed key exchange we support is ECDHE. + if !supportsECDHE(config, chi.SupportedCurves, chi.SupportedPoints) { + return supportsRSAFallback(errors.New("client doesn't support ECDHE, can only use legacy RSA key exchange")) + } + + var ecdsaCipherSuite bool + if priv, ok := c.PrivateKey.(crypto.Signer); ok { + switch pub := priv.Public().(type) { + case *ecdsa.PublicKey: + var curve CurveID + switch pub.Curve { + case elliptic.P256(): + curve = CurveP256 + case elliptic.P384(): + curve = CurveP384 + case elliptic.P521(): + curve = CurveP521 + default: + return supportsRSAFallback(unsupportedCertificateError(c)) + } + var curveOk bool + for _, c := range chi.SupportedCurves { + if c == curve && config.supportsCurve(c) { + curveOk = true + break + } + } + if !curveOk { + return errors.New("client doesn't support certificate curve") + } + ecdsaCipherSuite = true + case ed25519.PublicKey: + if vers < VersionTLS12 || len(chi.SignatureSchemes) == 0 { + return errors.New("connection doesn't support Ed25519") + } + ecdsaCipherSuite = true + case *rsa.PublicKey: + default: + return supportsRSAFallback(unsupportedCertificateError(c)) + } + } else { + return supportsRSAFallback(unsupportedCertificateError(c)) + } + + // Make sure that there is a mutually supported cipher suite that works with + // this certificate. Cipher suite selection will then apply the logic in + // reverse to pick it. See also serverHandshakeState.cipherSuiteOk. + cipherSuite := selectCipherSuite(chi.CipherSuites, config.cipherSuites(), func(c *cipherSuite) bool { + if c.flags&suiteECDHE == 0 { + return false + } + if c.flags&suiteECSign != 0 { + if !ecdsaCipherSuite { + return false + } + } else { + if ecdsaCipherSuite { + return false + } + } + if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { + return false + } + return true + }) + if cipherSuite == nil { + return supportsRSAFallback(errors.New("client doesn't support any cipher suites compatible with the certificate")) + } + + return nil +} + +// SupportsCertificate returns nil if the provided certificate is supported by +// the server that sent the CertificateRequest. Otherwise, it returns an error +// describing the reason for the incompatibility. +func (cri *CertificateRequestInfo) SupportsCertificate(c *Certificate) error { + if _, err := selectSignatureScheme(cri.Version, c, cri.SignatureSchemes); err != nil { + return err + } + + if len(cri.AcceptableCAs) == 0 { + return nil + } + + for j, cert := range c.Certificate { + x509Cert := c.Leaf + // Parse the certificate if this isn't the leaf node, or if + // chain.Leaf was nil. + if j != 0 || x509Cert == nil { + var err error + if x509Cert, err = x509.ParseCertificate(cert); err != nil { + return fmt.Errorf("failed to parse certificate #%d in the chain: %w", j, err) + } + } + + for _, ca := range cri.AcceptableCAs { + if bytes.Equal(x509Cert.RawIssuer, ca) { + return nil + } + } + } + return errors.New("chain is not signed by an acceptable CA") +} + +// BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate +// from the CommonName and SubjectAlternateName fields of each of the leaf +// certificates. +// +// Deprecated: NameToCertificate only allows associating a single certificate +// with a given name. Leave that field nil to let the library select the first +// compatible chain from Certificates. +func (c *Config) BuildNameToCertificate() { + c.NameToCertificate = make(map[string]*Certificate) + for i := range c.Certificates { + cert := &c.Certificates[i] + x509Cert, err := cert.leaf() + if err != nil { + continue + } + if len(x509Cert.Subject.CommonName) > 0 { + c.NameToCertificate[x509Cert.Subject.CommonName] = cert + } + for _, san := range x509Cert.DNSNames { + c.NameToCertificate[san] = cert + } + } +} + +const ( + keyLogLabelTLS12 = "CLIENT_RANDOM" + keyLogLabelClientHandshake = "CLIENT_HANDSHAKE_TRAFFIC_SECRET" + keyLogLabelServerHandshake = "SERVER_HANDSHAKE_TRAFFIC_SECRET" + keyLogLabelClientTraffic = "CLIENT_TRAFFIC_SECRET_0" + keyLogLabelServerTraffic = "SERVER_TRAFFIC_SECRET_0" +) + +func (c *Config) writeKeyLog(label string, clientRandom, secret []byte) error { + if c.KeyLogWriter == nil { + return nil + } + + logLine := []byte(fmt.Sprintf("%s %x %x\n", label, clientRandom, secret)) + + writerMutex.Lock() + _, err := c.KeyLogWriter.Write(logLine) + writerMutex.Unlock() + + return err +} + +// writerMutex protects all KeyLogWriters globally. It is rarely enabled, +// and is only for debugging, so a global mutex saves space. +var writerMutex sync.Mutex + +// A Certificate is a chain of one or more certificates, leaf first. +type Certificate struct { + Certificate [][]byte + // PrivateKey contains the private key corresponding to the public key in + // Leaf. This must implement crypto.Signer with an RSA, ECDSA or Ed25519 PublicKey. + // For a server up to TLS 1.2, it can also implement crypto.Decrypter with + // an RSA PublicKey. + PrivateKey crypto.PrivateKey + // SupportedSignatureAlgorithms is an optional list restricting what + // signature algorithms the PrivateKey can be used for. + SupportedSignatureAlgorithms []SignatureScheme + // OCSPStaple contains an optional OCSP response which will be served + // to clients that request it. + OCSPStaple []byte + // SignedCertificateTimestamps contains an optional list of Signed + // Certificate Timestamps which will be served to clients that request it. + SignedCertificateTimestamps [][]byte + // Leaf is the parsed form of the leaf certificate, which may be initialized + // using x509.ParseCertificate to reduce per-handshake processing. If nil, + // the leaf certificate will be parsed as needed. + Leaf *x509.Certificate +} + +// leaf returns the parsed leaf certificate, either from c.Leaf or by parsing +// the corresponding c.Certificate[0]. +func (c *Certificate) leaf() (*x509.Certificate, error) { + if c.Leaf != nil { + return c.Leaf, nil + } + return x509.ParseCertificate(c.Certificate[0]) +} + +type handshakeMessage interface { + marshal() []byte + unmarshal([]byte) bool +} + +// lruSessionCache is a ClientSessionCache implementation that uses an LRU +// caching strategy. +type lruSessionCache struct { + sync.Mutex + + m map[string]*list.Element + q *list.List + capacity int +} + +type lruSessionCacheEntry struct { + sessionKey string + state *ClientSessionState +} + +// NewLRUClientSessionCache returns a ClientSessionCache with the given +// capacity that uses an LRU strategy. If capacity is < 1, a default capacity +// is used instead. +func NewLRUClientSessionCache(capacity int) ClientSessionCache { + const defaultSessionCacheCapacity = 64 + + if capacity < 1 { + capacity = defaultSessionCacheCapacity + } + return &lruSessionCache{ + m: make(map[string]*list.Element), + q: list.New(), + capacity: capacity, + } +} + +// Put adds the provided (sessionKey, cs) pair to the cache. If cs is nil, the entry +// corresponding to sessionKey is removed from the cache instead. +func (c *lruSessionCache) Put(sessionKey string, cs *ClientSessionState) { + c.Lock() + defer c.Unlock() + + if elem, ok := c.m[sessionKey]; ok { + if cs == nil { + c.q.Remove(elem) + delete(c.m, sessionKey) + } else { + entry := elem.Value.(*lruSessionCacheEntry) + entry.state = cs + c.q.MoveToFront(elem) + } + return + } + + if c.q.Len() < c.capacity { + entry := &lruSessionCacheEntry{sessionKey, cs} + c.m[sessionKey] = c.q.PushFront(entry) + return + } + + elem := c.q.Back() + entry := elem.Value.(*lruSessionCacheEntry) + delete(c.m, entry.sessionKey) + entry.sessionKey = sessionKey + entry.state = cs + c.q.MoveToFront(elem) + c.m[sessionKey] = elem +} + +// Get returns the ClientSessionState value associated with a given key. It +// returns (nil, false) if no value is found. +func (c *lruSessionCache) Get(sessionKey string) (*ClientSessionState, bool) { + c.Lock() + defer c.Unlock() + + if elem, ok := c.m[sessionKey]; ok { + c.q.MoveToFront(elem) + return elem.Value.(*lruSessionCacheEntry).state, true + } + return nil, false +} + +var emptyConfig Config + +func defaultConfig() *Config { + return &emptyConfig +} + +var ( + once sync.Once + varDefaultCipherSuites []uint16 + varDefaultCipherSuitesTLS13 []uint16 +) + +func defaultCipherSuites() []uint16 { + once.Do(initDefaultCipherSuites) + return varDefaultCipherSuites +} + +func defaultCipherSuitesTLS13() []uint16 { + once.Do(initDefaultCipherSuites) + return varDefaultCipherSuitesTLS13 +} + +func initDefaultCipherSuites() { + var topCipherSuites []uint16 + + // Check the cpu flags for each platform that has optimized GCM implementations. + // Worst case, these variables will just all be false. + var ( + hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ + hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL + // Keep in sync with crypto/aes/cipher_s390x.go. + hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR && (cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM) + + hasGCMAsm = hasGCMAsmAMD64 || hasGCMAsmARM64 || hasGCMAsmS390X + ) + + if hasGCMAsm { + // If AES-GCM hardware is provided then prioritise AES-GCM + // cipher suites. + topCipherSuites = []uint16{ + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + } + varDefaultCipherSuitesTLS13 = []uint16{ + TLS_AES_128_GCM_SHA256, + TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_256_GCM_SHA384, + } + } else { + // Without AES-GCM hardware, we put the ChaCha20-Poly1305 + // cipher suites first. + topCipherSuites = []uint16{ + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + } + varDefaultCipherSuitesTLS13 = []uint16{ + TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, + } + } + + varDefaultCipherSuites = make([]uint16, 0, len(cipherSuites)) + varDefaultCipherSuites = append(varDefaultCipherSuites, topCipherSuites...) + +NextCipherSuite: + for _, suite := range cipherSuites { + if suite.flags&suiteDefaultOff != 0 { + continue + } + for _, existing := range varDefaultCipherSuites { + if existing == suite.id { + continue NextCipherSuite + } + } + varDefaultCipherSuites = append(varDefaultCipherSuites, suite.id) + } +} + +func unexpectedMessageError(wanted, got interface{}) error { + return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted) +} + +func isSupportedSignatureAlgorithm(sigAlg SignatureScheme, supportedSignatureAlgorithms []SignatureScheme) bool { + for _, s := range supportedSignatureAlgorithms { + if s == sigAlg { + return true + } + } + return false +} diff --git a/pkg/tls/conn.go b/pkg/tls/conn.go new file mode 100644 index 000000000..5c588499f --- /dev/null +++ b/pkg/tls/conn.go @@ -0,0 +1,1260 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TLS low level connection and record layer + +package tls + +import ( + "crypto/cipher" + "crypto/subtle" + "crypto/x509" + "errors" + "fmt" + "io" + "net" + "sync" + + "github.com/panjf2000/gnet/v2/pkg/buffer/elastic" +) + +// A Conn represents a secured connection. +// It implements the net.Conn interface. +type Conn struct { + // constant + conn conn + isClient bool + // handshakeStatus is 1 if the connection is currently transferring + // application data (i.e. is not currently processing a handshake). + // This field is only to be accessed with sync/atomic. + handshakeStatus uint8 + // constant after handshake; protected by handshakeMutex + handshakeMutex sync.Mutex + handshakeErr error // error resulting from handshake + vers uint16 // TLS version + haveVers bool // version has been negotiated + config *Config // configuration passed to constructor + // handshakes counts the number of handshakes performed on the + // connection so far. If renegotiation is disabled then this is either + // zero or one. + handshakes int + didResume bool // whether this connection was a session resumption + cipherSuite uint16 + ocspResponse []byte // stapled OCSP response + scts [][]byte // signed certificate timestamps from server + peerCertificates []*x509.Certificate + // verifiedChains contains the certificate chains that we built, as + // opposed to the ones presented by the server. + verifiedChains [][]*x509.Certificate + // serverName contains the server name indicated by the client, if any. + serverName string + // secureRenegotiation is true if the server echoed the secure + // renegotiation extension. (This is meaningless as a server because + // renegotiation is not supported in that case.) + secureRenegotiation bool + // ekm is a closure for exporting keying material. + ekm func(label string, context []byte, length int) ([]byte, error) + // resumptionSecret is the resumption_master_secret for handling + // NewSessionTicket messages. nil if config.SessionTicketsDisabled. + resumptionSecret []byte + + // ticketKeys is the set of active session ticket keys for this + // connection. The first one is used to encrypt new tickets and + // all are tried to decrypt tickets. + ticketKeys []ticketKey + + // clientFinishedIsFirst is true if the client sent the first Finished + // message during the most recent handshake. This is recorded because + // the first transmitted Finished message is the tls-unique + // channel-binding value. + clientFinishedIsFirst bool + + // closeNotifyErr is any error from sending the alertCloseNotify record. + closeNotifyErr error + // closeNotifySent is true if the Conn attempted to send an + // alertCloseNotify record. + closeNotifySent bool + + // clientFinished and serverFinished contain the Finished message sent + // by the client or server in the most recent handshake. This is + // retained to support the renegotiation extension and tls-unique + // channel-binding. + clientFinished [12]byte + serverFinished [12]byte + + clientProtocol string + clientProtocolFallback bool + + // input/output + in, out halfConn + rawInput MsgBuffer // raw input, starting with a record header + input *elastic.RingBuffer // a buffer for decrypted records + // pointer to the inboundBuffer of gnet.conn + hand MsgBuffer // handshake data waiting to be read + outBuf []byte // scratch buffer used by out.encrypt + buffering bool // whether records are buffered in sendBuf + sendBuf *elastic.Buffer // a buffer for records waiting to be sent + // also point to the outboundBuffer of gnet.conn + + // bytesSent counts the bytes of application data sent. + // packetsSent counts packets. + bytesSent int64 + packetsSent int64 + + // retryCount counts the number of consecutive non-advancing records + // received by Conn.readRecord. That is, records that neither advance the + // handshake, nor deliver application data. Protected by in.Mutex. + retryCount int + + tmp [16]byte + hs interface { + handshake() error + } +} + +// A halfConn represents one direction of the record layer +// connection, either sending or receiving. +type halfConn struct { + sync.Mutex + + err error // first permanent error + version uint16 // protocol version + cipher interface{} // cipher algorithm + mac macFunction + seq [8]byte // 64-bit sequence number + additionalData [13]byte // to avoid allocs; interface method args escape + + nextCipher interface{} // next encryption state + nextMac macFunction // next MAC algorithm + + trafficSecret []byte // current TLS 1.3 traffic secret +} + +type permamentError struct { + err net.Error +} + +func (e *permamentError) Error() string { return e.err.Error() } +func (e *permamentError) Unwrap() error { return e.err } +func (e *permamentError) Timeout() bool { return e.err.Timeout() } +func (e *permamentError) Temporary() bool { return false } + +func (hc *halfConn) setErrorLocked(err error) error { + if e, ok := err.(net.Error); ok { + hc.err = &permamentError{err: e} + } else { + hc.err = err + } + return hc.err +} + +// prepareCipherSpec sets the encryption and MAC states +// that a subsequent changeCipherSpec will use. +func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) { + hc.version = version + hc.nextCipher = cipher + hc.nextMac = mac +} + +// changeCipherSpec changes the encryption and MAC states +// to the ones previously passed to prepareCipherSpec. +func (hc *halfConn) changeCipherSpec() error { + if hc.nextCipher == nil || hc.version == VersionTLS13 { + return alertInternalError + } + hc.cipher = hc.nextCipher + hc.mac = hc.nextMac + hc.nextCipher = nil + hc.nextMac = nil + for i := range hc.seq { + hc.seq[i] = 0 + } + return nil +} + +func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) { + hc.trafficSecret = secret + key, iv := suite.trafficKey(secret) + hc.cipher = suite.aead(key, iv) + for i := range hc.seq { + hc.seq[i] = 0 + } +} + +// incSeq increments the sequence number. +func (hc *halfConn) incSeq() { + for i := 7; i >= 0; i-- { + hc.seq[i]++ + if hc.seq[i] != 0 { + return + } + } + + // Not allowed to let sequence number wrap. + // Instead, must renegotiate before it does. + // Not likely enough to bother. + panic("TLS: sequence number wraparound") +} + +// explicitNonceLen returns the number of bytes of explicit nonce or IV included +// in each record. Explicit nonces are present only in CBC modes after TLS 1.0 +// and in certain AEAD modes in TLS 1.2. +func (hc *halfConn) explicitNonceLen() int { + if hc.cipher == nil { + return 0 + } + + switch c := hc.cipher.(type) { + case cipher.Stream: + return 0 + case aead: + return c.explicitNonceLen() + case cbcMode: + // TLS 1.1 introduced a per-record explicit IV to fix the BEAST attack. + if hc.version >= VersionTLS11 { + return c.BlockSize() + } + return 0 + default: + panic("unknown cipher type") + } +} + +// extractPadding returns, in constant time, the length of the padding to remove +// from the end of payload. It also returns a byte which is equal to 255 if the +// padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2. +func extractPadding(payload []byte) (toRemove int, good byte) { + if len(payload) < 1 { + return 0, 0 + } + + paddingLen := payload[len(payload)-1] + t := uint(len(payload)-1) - uint(paddingLen) + // if len(payload) >= (paddingLen - 1) then the MSB of t is zero + good = byte(int32(^t) >> 31) + + // The maximum possible padding length plus the actual length field + toCheck := 256 + // The length of the padded data is public, so we can use an if here + if toCheck > len(payload) { + toCheck = len(payload) + } + + for i := 0; i < toCheck; i++ { + t := uint(paddingLen) - uint(i) + // if i <= paddingLen then the MSB of t is zero + mask := byte(int32(^t) >> 31) + b := payload[len(payload)-1-i] + good &^= mask&paddingLen ^ mask&b + } + + // We AND together the bits of good and replicate the result across + // all the bits. + good &= good << 4 + good &= good << 2 + good &= good << 1 + good = uint8(int8(good) >> 7) + + // Zero the padding length on error. This ensures any unchecked bytes + // are included in the MAC. Otherwise, an attacker that could + // distinguish MAC failures from padding failures could mount an attack + // similar to POODLE in SSL 3.0: given a good ciphertext that uses a + // full block's worth of padding, replace the final block with another + // block. If the MAC check passed but the padding check failed, the + // last byte of that block decrypted to the block size. + // + // See also macAndPaddingGood logic below. + paddingLen &= good + + toRemove = int(paddingLen) + 1 + return +} + +func roundUp(a, b int) int { + return a + (b-a%b)%b +} + +// cbcMode is an interface for block ciphers using cipher block chaining. +type cbcMode interface { + cipher.BlockMode + SetIV([]byte) +} + +// decrypt authenticates and decrypts the record if protection is active at +// this stage. The returned plaintext might overlap with the input. +func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) { + var plaintext []byte + typ := recordType(record[0]) + payload := record[recordHeaderLen:] + + // In TLS 1.3, change_cipher_spec messages are to be ignored without being + // decrypted. See RFC 8446, Appendix D.4. + if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec { + return payload, typ, nil + } + + paddingGood := byte(255) + paddingLen := 0 + + explicitNonceLen := hc.explicitNonceLen() + + if hc.cipher != nil { + switch c := hc.cipher.(type) { + case cipher.Stream: + c.XORKeyStream(payload, payload) + case aead: + if len(payload) < explicitNonceLen { + return nil, 0, alertBadRecordMAC + } + nonce := payload[:explicitNonceLen] + if len(nonce) == 0 { + nonce = hc.seq[:] + } + payload = payload[explicitNonceLen:] + + additionalData := hc.additionalData[:] + if hc.version == VersionTLS13 { + additionalData = record[:recordHeaderLen] + } else { + copy(additionalData, hc.seq[:]) + copy(additionalData[8:], record[:3]) + n := len(payload) - c.Overhead() + additionalData[11] = byte(n >> 8) + additionalData[12] = byte(n) + } + + var err error + plaintext, err = c.Open(payload[:0], nonce, payload, additionalData) + if err != nil { + return nil, 0, alertBadRecordMAC + } + case cbcMode: + blockSize := c.BlockSize() + minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize) + if len(payload)%blockSize != 0 || len(payload) < minPayload { + return nil, 0, alertBadRecordMAC + } + + if explicitNonceLen > 0 { + c.SetIV(payload[:explicitNonceLen]) + payload = payload[explicitNonceLen:] + } + c.CryptBlocks(payload, payload) + + // In a limited attempt to protect against CBC padding oracles like + // Lucky13, the data past paddingLen (which is secret) is passed to + // the MAC function as extra data, to be fed into the HMAC after + // computing the digest. This makes the MAC roughly constant time as + // long as the digest computation is constant time and does not + // affect the subsequent write, modulo cache effects. + paddingLen, paddingGood = extractPadding(payload) + default: + panic("unknown cipher type") + } + + if hc.version == VersionTLS13 { + if typ != recordTypeApplicationData { + return nil, 0, alertUnexpectedMessage + } + if len(plaintext) > maxPlaintext+1 { + return nil, 0, alertRecordOverflow + } + // Remove padding and find the ContentType scanning from the end. + for i := len(plaintext) - 1; i >= 0; i-- { + if plaintext[i] != 0 { + typ = recordType(plaintext[i]) + plaintext = plaintext[:i] + break + } + if i == 0 { + return nil, 0, alertUnexpectedMessage + } + } + } + } else { + plaintext = payload + } + + if hc.mac != nil { + macSize := hc.mac.Size() + if len(payload) < macSize { + return nil, 0, alertBadRecordMAC + } + + n := len(payload) - macSize - paddingLen + n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 } + record[3] = byte(n >> 8) + record[4] = byte(n) + remoteMAC := payload[n : n+macSize] + localMAC := hc.mac.MAC(hc.seq[0:], record[:recordHeaderLen], payload[:n], payload[n+macSize:]) + + // This is equivalent to checking the MACs and paddingGood + // separately, but in constant-time to prevent distinguishing + // padding failures from MAC failures. Depending on what value + // of paddingLen was returned on bad padding, distinguishing + // bad MAC from bad padding can lead to an attack. + // + // See also the logic at the end of extractPadding. + macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood) + if macAndPaddingGood != 1 { + return nil, 0, alertBadRecordMAC + } + + plaintext = payload[:n] + } + + hc.incSeq() + return plaintext, typ, nil +} +func sliceForAppend(in []byte, n int) (head, tail []byte) { + if total := len(in) + n; cap(in) >= total { + head = in[:total] + } else { + head = make([]byte, total) + copy(head, in) + } + tail = head[len(in):] + return +} + +// encrypt encrypts payload, adding the appropriate nonce and/or MAC, and +// appends it to record, which contains the record header. +func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) { + if hc.cipher == nil { + return append(record, payload...), nil + } + + var explicitNonce []byte + if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 { + record, explicitNonce = sliceForAppend(record, explicitNonceLen) + if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 { + // The AES-GCM construction in TLS has an explicit nonce so that the + // nonce can be random. However, the nonce is only 8 bytes which is + // too small for a secure, random nonce. Therefore we use the + // sequence number as the nonce. The 3DES-CBC construction also has + // an 8 bytes nonce but its nonces must be unpredictable (see RFC + // 5246, Appendix F.3), forcing us to use randomness. That's not + // 3DES' biggest problem anyway because the birthday bound on block + // collision is reached first due to its simlarly small block size + // (see the Sweet32 attack). + copy(explicitNonce, hc.seq[:]) + } else { + if _, err := io.ReadFull(rand, explicitNonce); err != nil { + return nil, err + } + } + } + + var mac []byte + if hc.mac != nil { + mac = hc.mac.MAC(hc.seq[:], record[:recordHeaderLen], payload, nil) + } + + var dst []byte + switch c := hc.cipher.(type) { + case cipher.Stream: + record, dst = sliceForAppend(record, len(payload)+len(mac)) + c.XORKeyStream(dst[:len(payload)], payload) + c.XORKeyStream(dst[len(payload):], mac) + case aead: + nonce := explicitNonce + if len(nonce) == 0 { + nonce = hc.seq[:] + } + + if hc.version == VersionTLS13 { + record = append(record, payload...) + + // Encrypt the actual ContentType and replace the plaintext one. + record = append(record, record[0]) + record[0] = byte(recordTypeApplicationData) + + n := len(payload) + 1 + c.Overhead() + record[3] = byte(n >> 8) + record[4] = byte(n) + + record = c.Seal(record[:recordHeaderLen], + nonce, record[recordHeaderLen:], record[:recordHeaderLen]) + } else { + copy(hc.additionalData[:], hc.seq[:]) + copy(hc.additionalData[8:], record) + record = c.Seal(record, nonce, payload, hc.additionalData[:]) + } + case cbcMode: + blockSize := c.BlockSize() + plaintextLen := len(payload) + len(mac) + paddingLen := blockSize - plaintextLen%blockSize + record, dst = sliceForAppend(record, plaintextLen+paddingLen) + copy(dst, payload) + copy(dst[len(payload):], mac) + for i := plaintextLen; i < len(dst); i++ { + dst[i] = byte(paddingLen - 1) + } + if len(explicitNonce) > 0 { + c.SetIV(explicitNonce) + } + c.CryptBlocks(dst, dst) + default: + panic("unknown cipher type") + } + + // Update length to include nonce, MAC and any block padding needed. + n := len(record) - recordHeaderLen + record[3] = byte(n >> 8) + record[4] = byte(n) + hc.incSeq() + + return record, nil +} + +// RecordHeaderError is returned when a TLS record header is invalid. +type RecordHeaderError struct { + // Msg contains a human readable string that describes the error. + Msg string + // RecordHeader contains the five bytes of TLS record header that + // triggered the error. + RecordHeader [5]byte + // Conn provides the underlying net.Conn in the case that a client + // sent an initial handshake that didn't look like TLS. + // It is nil if there's already been a handshake or a TLS alert has + // been written to the connection. + Conn conn +} + +func (e RecordHeaderError) Error() string { return "tls: " + e.Msg } + +func (c *Conn) newRecordHeaderError(conn conn, msg string) (err RecordHeaderError) { + err.Msg = msg + err.Conn = conn + copy(err.RecordHeader[:], c.rawInput.Bytes()) + return err +} + +func (c *Conn) readRecord() error { + if c.rawInput.Len() > 5 { + return c.readRecordOrCCS(false) + } + return io.EOF +} + +func (c *Conn) readChangeCipherSpec() error { + c.input.Reset() + return c.readRecordOrCCS(true) +} + +// readRecordOrCCS reads one or more TLS records from the connection and +// updates the record layer state. Some invariants: +// * c.in must be locked +// * c.input must be empty +// During the handshake one and only one of the following will happen: +// - c.hand grows +// - c.in.changeCipherSpec is called +// - an error is returned +// After the handshake one and only one of the following will happen: +// - c.hand grows +// - c.input is set +// - an error is returned +func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { + if c.in.err != nil { + return c.in.err + } + + hdr := c.rawInput.Bytes() + typ := recordType(hdr[0]) + + // No valid TLS record has a type of 0x80, however SSLv2 handshakes + // start with a uint16 length where the MSB is set and the first record + // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests + // an SSLv2 client. + + vers := uint16(hdr[1])<<8 | uint16(hdr[2]) + n := int(hdr[3])<<8 | int(hdr[4]) + if len(hdr) < recordHeaderLen+n { + return io.EOF + } + // Read header, payload. + if c.handshakeStatus != 255 && typ == 0x80 { + c.sendAlert(alertProtocolVersion) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received")) + } + if c.haveVers { + if c.vers != VersionTLS13 && vers != c.vers { + c.sendAlert(alertProtocolVersion) + msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) + } + } else { + // First message, be extra suspicious: this might not be a TLS + // client. Bail out before reading a full 'body', if possible. + // The current max version is 3.3 so if the version is >= 16.0, + // it's probably not real. + if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 { + return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake")) + } + } + if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext { + c.sendAlert(alertRecordOverflow) + msg := fmt.Sprintf("oversized record received with length %d", n) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) + } + + // Process message. + c.rawInput.Shift(recordHeaderLen + n) + data, typ, err := c.in.decrypt(hdr[:recordHeaderLen+n]) + + if err != nil { + return c.in.setErrorLocked(c.sendAlert(err.(alert))) + } + if len(data) > maxPlaintext { + return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow)) + } + + // Application Data messages are always protected. + if c.in.cipher == nil && typ == recordTypeApplicationData { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 { + // This is a state-advancing message: reset the retry count. + c.retryCount = 0 + } + + // Handshake messages MUST NOT be interleaved with other record types in TLS 1.3. + if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + switch typ { + default: + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + + case recordTypeAlert: + if len(data) != 2 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + if alert(data[1]) == alertCloseNotify { + return c.in.setErrorLocked(io.EOF) + } + if c.vers == VersionTLS13 { + return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) + } + switch data[0] { + case alertLevelWarning: + // Drop the record on the floor and retry. + return c.retryReadRecord(expectChangeCipherSpec) + case alertLevelError: + return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) + default: + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + case recordTypeChangeCipherSpec: + if len(data) != 1 || data[0] != 1 { + return c.in.setErrorLocked(c.sendAlert(alertDecodeError)) + } + // Handshake messages are not allowed to fragment across the CCS. + if c.hand.Len() > 0 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + // In TLS 1.3, change_cipher_spec records are ignored until the + // Finished. See RFC 8446, Appendix D.4. Note that according to Section + // 5, a server can send a ChangeCipherSpec before its ServerHello, when + // c.vers is still unset. That's not useful though and suspicious if the + // server then selects a lower protocol version, so don't allow that. + if c.vers == VersionTLS13 { + return c.retryReadRecord(expectChangeCipherSpec) + } + if !expectChangeCipherSpec { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + if err := c.in.changeCipherSpec(); err != nil { + return c.in.setErrorLocked(c.sendAlert(err.(alert))) + } + + case recordTypeApplicationData: + if c.handshakeStatus != 255 || expectChangeCipherSpec { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + // Some OpenSSL servers send empty records in order to randomize the + // CBC IV. Ignore a limited number of empty records. + if len(data) == 0 { + return c.retryReadRecord(expectChangeCipherSpec) + } + // Note that data is owned by c.rawInput, following the Next call above, + // to avoid copying the plaintext. This is safe because c.rawInput is + // not read from or written to until c.input is drained. + c.input.Write(data) + + case recordTypeHandshake: + if len(data) == 0 || expectChangeCipherSpec { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + c.hand.Write(data) + } + + return nil +} + +// retryReadRecord recurses into readRecordOrCCS to drop a non-advancing record, like +// a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3. +func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error { + c.retryCount++ + if c.retryCount > maxUselessRecords { + c.sendAlert(alertUnexpectedMessage) + return c.in.setErrorLocked(errors.New("tls: too many ignored records")) + } + c.input.Reset() + if c.rawInput.Len() > 5 { + return c.readRecordOrCCS(expectChangeCipherSpec) + } + return io.EOF +} + +// sendAlert sends a TLS alert message. +func (c *Conn) sendAlertLocked(err alert) error { + switch err { + case alertNoRenegotiation, alertCloseNotify: + c.tmp[0] = alertLevelWarning + default: + c.tmp[0] = alertLevelError + } + c.tmp[1] = byte(err) + + _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2]) + if err == alertCloseNotify { + // closeNotify is a special case in that it isn't an error. + return writeErr + } + + return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) +} + +// sendAlert sends a TLS alert message. +func (c *Conn) sendAlert(err alert) error { + + return c.sendAlertLocked(err) +} + +const ( + // tcpMSSEstimate is a conservative estimate of the TCP maximum segment + // size (MSS). A constant is used, rather than querying the kernel for + // the actual MSS, to avoid complexity. The value here is the IPv6 + // minimum MTU (1280 bytes) minus the overhead of an IPv6 header (40 + // bytes) and a TCP header with timestamps (32 bytes). + tcpMSSEstimate = 1208 + + // recordSizeBoostThreshold is the number of bytes of application data + // sent after which the TLS record size will be increased to the + // maximum. + recordSizeBoostThreshold = 128 * 1024 +) + +// maxPayloadSizeForWrite returns the maximum TLS payload size to use for the +// next application data record. There is the following trade-off: +// +// - For latency-sensitive applications, such as web browsing, each TLS +// record should fit in one TCP segment. +// - For throughput-sensitive applications, such as large file transfers, +// larger TLS records better amortize framing and encryption overheads. +// +// A simple heuristic that works well in practice is to use small records for +// the first 1MB of data, then use larger records for subsequent data, and +// reset back to smaller records after the connection becomes idle. See "High +// Performance Web Networking", Chapter 4, or: +// https://www.igvita.com/2013/10/24/optimizing-tls-record-size-and-buffering-latency/ +// +// In the interests of simplicity and determinism, this code does not attempt +// to reset the record size once the connection is idle, however. +func (c *Conn) maxPayloadSizeForWrite(typ recordType) int { + if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData { + return maxPlaintext + } + + if c.bytesSent >= recordSizeBoostThreshold { + return maxPlaintext + } + + // Subtract TLS overheads to get the maximum payload size. + payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen() + if c.out.cipher != nil { + switch ciph := c.out.cipher.(type) { + case cipher.Stream: + payloadBytes -= c.out.mac.Size() + case cipher.AEAD: + payloadBytes -= ciph.Overhead() + case cbcMode: + blockSize := ciph.BlockSize() + // The payload must fit in a multiple of blockSize, with + // room for at least one padding byte. + payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1 + // The MAC is appended before padding so affects the + // payload size directly. + payloadBytes -= c.out.mac.Size() + default: + panic("unknown cipher type") + } + } + if c.vers == VersionTLS13 { + payloadBytes-- // encrypted ContentType + } + + // Allow packet growth in arithmetic progression up to max. + pkt := c.packetsSent + c.packetsSent++ + if pkt > 1000 { + return maxPlaintext // avoid overflow in multiply below + } + + n := payloadBytes * int(pkt+1) + if n > maxPlaintext { + n = maxPlaintext + } + return n +} + +func (c *Conn) write(data []byte) (n int, err error) { + //必须把所有数据往buf写 + n = len(data) + c.sendBuf.Write(data) + c.bytesSent += int64(n) + return +} + +func (c *Conn) flush() (int, error) { + if c.sendBuf.Buffered() == 0 { + return 0, nil + } + n, err := c.conn.Write(nil) + c.bytesSent += int64(n) + c.buffering = false + return n, err +} + +// writeRecordLocked writes a TLS record with the given type and payload to the +// connection and updates the record layer state. +func (c *Conn) writeRecordLocked(typ recordType, data []byte) (n int, err error) { + + for len(data) > 0 { + m := len(data) + if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload { + m = maxPayload + } + + _, c.outBuf = sliceForAppend(c.outBuf[:0], recordHeaderLen) + c.outBuf[0] = byte(typ) + /*vers := c.vers + if vers == 0 { + // Some TLS servers fail if the record version is + // greater than TLS 1.0 for the initial ClientHello. + vers = VersionTLS10 + } else if vers == VersionTLS13 { + // TLS 1.3 froze the record layer version to 1.2. + // See RFC 8446, Section 5.1. + vers = VersionTLS12 + } + c.outBuf[1] = byte(vers >> 8) + c.outBuf[2] = byte(vers)*/ + c.outBuf[3] = byte(m >> 8) + c.outBuf[4] = byte(m) + + c.outBuf, err = c.out.encrypt(c.outBuf, data[:m], c.config.rand()) + if err != nil { + return n, err + } + if _, err = c.write(c.outBuf); err != nil { + return n, err + } + n += m + data = data[m:] + } + + if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 { + if err = c.out.changeCipherSpec(); err != nil { + return n, c.sendAlertLocked(err.(alert)) + } + } + + return +} + +// writeRecord writes a TLS record with the given type and payload to the +// connection and updates the record layer state. +func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) { + + return c.writeRecordLocked(typ, data) +} + +// readHandshake reads the next handshake message from +// the record layer. +func (c *Conn) readHandshake() (interface{}, error) { + for c.hand.Len() < 4 { + if err := c.readRecord(); err != nil { + return nil, err + } + } + + data := c.hand.PreBytes(4) + n := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + if n > maxHandshake { + c.sendAlertLocked(alertInternalError) + return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)) + } + for c.hand.Len() < 4+n { + if err := c.readRecord(); err != nil { + return nil, err + } + } + data = c.hand.Next(4 + n) + var m handshakeMessage + switch data[0] { + case typeHelloRequest: + m = new(helloRequestMsg) + case typeClientHello: + m = new(clientHelloMsg) + case typeServerHello: + m = new(serverHelloMsg) + case typeNewSessionTicket: + if c.vers == VersionTLS13 { + m = new(newSessionTicketMsgTLS13) + } else { + m = new(newSessionTicketMsg) + } + case typeCertificate: + if c.vers == VersionTLS13 { + m = new(certificateMsgTLS13) + } else { + m = new(certificateMsg) + } + case typeCertificateRequest: + if c.vers == VersionTLS13 { + m = new(certificateRequestMsgTLS13) + } else { + m = &certificateRequestMsg{ + hasSignatureAlgorithm: c.vers >= VersionTLS12, + } + } + case typeCertificateStatus: + m = new(certificateStatusMsg) + case typeServerKeyExchange: + m = new(serverKeyExchangeMsg) + case typeServerHelloDone: + m = new(serverHelloDoneMsg) + case typeClientKeyExchange: + m = new(clientKeyExchangeMsg) + case typeCertificateVerify: + m = &certificateVerifyMsg{ + hasSignatureAlgorithm: c.vers >= VersionTLS12, + } + case typeFinished: + m = new(finishedMsg) + case typeEncryptedExtensions: + m = new(encryptedExtensionsMsg) + case typeEndOfEarlyData: + m = new(endOfEarlyDataMsg) + case typeKeyUpdate: + m = new(keyUpdateMsg) + default: + return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + // The handshake message unmarshalers + // expect to be able to keep references to data, + // so pass in a fresh copy that won't be overwritten. + data = append([]byte(nil), data...) + + if !m.unmarshal(data) { + return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + return m, nil +} + +var ( + errClosed = errors.New("tls: use of closed connection") + errShutdown = errors.New("tls: protocol is shutdown") +) + +// Write writes data to the connection. +func (c *Conn) Write(b []byte) error { + // interlock with Close below + + if c.handshakeStatus != 255 { + return nil + } + + c.buffering = false + + if err := c.out.err; err != nil { + return err + } + + if c.closeNotifySent { + return errShutdown + } + + // TLS 1.0 is susceptible to a chosen-plaintext + // attack when using block mode ciphers due to predictable IVs. + // This can be prevented by splitting each Application Data + // record into two records, effectively randomizing the IV. + // + // https://www.openssl.org/~bodo/tls-cbc.txt + // https://bugzilla.mozilla.org/show_bug.cgi?id=665814 + // https://www.imperialviolet.org/2012/01/15/beastfollowup.html + + _, err := c.writeRecordLocked(recordTypeApplicationData, b) + return c.out.setErrorLocked(err) +} + +// load the data into the TLS rawInput +func (c *Conn) RawWrite(data []byte) (int, error) { + + c.rawInput.Write(data) + return len(data), nil +} + +// Decrypt one tls record and save it in the 解析一条tls数据 +func (c *Conn) ReadFrame() error { + if c.rawInput.Len() > 5 { + return c.readRecordOrCCS(false) + } + return io.EOF +} + +func (c *Conn) RawData() []byte { + return c.rawInput.Bytes() +} + +// Close closes the connection. +var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete") + +// CloseWrite shuts down the writing side of the connection. It should only be +// called once the handshake has completed and does not call CloseWrite on the +// underlying connection. Most callers should just use Close. +func (c *Conn) CloseWrite() error { + if c.handshakeStatus != 255 { + return errEarlyCloseWrite + } + + return c.closeNotify() +} + +func (c *Conn) closeNotify() error { + if !c.closeNotifySent { + c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify) + c.closeNotifySent = true + } + return c.closeNotifyErr +} + +// Handshake runs the client or server handshake +// protocol if it has not yet been run. +// Most uses of this package need not call Handshake +// explicitly: the first Read or Write will call it automatically. +func (c *Conn) Handshake() error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + if err := c.handshakeErr; err != nil { + return err + } + if c.handshakeStatus == 255 { + return nil + } + + if c.isClient { + c.handshakeErr = c.clientHandshake() + } else { + c.handshakeErr = c.serverHandshake() + } + + if c.handshakeErr == io.EOF { + c.handshakeErr = nil + } + if c.handshakeErr == nil { + c.handshakes++ + } else { + //panic(c.handshakeErr) + // If an error occurred during the handshake try to flush the + // alert that might be left in the buffer. + c.flush() + } + + return c.handshakeErr +} + +// ConnectionState returns basic TLS details about the connection. +func (c *Conn) ConnectionState() ConnectionState { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + var state ConnectionState + state.HandshakeComplete = c.handshakeStatus == 255 + state.ServerName = c.serverName + + if state.HandshakeComplete { + state.Version = c.vers + state.NegotiatedProtocol = c.clientProtocol + state.DidResume = c.didResume + state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback + state.CipherSuite = c.cipherSuite + state.PeerCertificates = c.peerCertificates + state.VerifiedChains = c.verifiedChains + state.SignedCertificateTimestamps = c.scts + state.OCSPResponse = c.ocspResponse + if !c.didResume && c.vers != VersionTLS13 { + if c.clientFinishedIsFirst { + state.TLSUnique = c.clientFinished[:] + } else { + state.TLSUnique = c.serverFinished[:] + } + } + if c.config.Renegotiation != RenegotiateNever { + state.ekm = noExportedKeyingMaterial + } else { + state.ekm = c.ekm + } + } + + return state +} + +// OCSPResponse returns the stapled OCSP response from the TLS server, if +// any. (Only valid for client connections.) +func (c *Conn) OCSPResponse() []byte { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + return c.ocspResponse +} + +// VerifyHostname checks that the peer certificate chain is valid for +// connecting to host. If so, it returns nil; if not, it returns an error +// describing the problem. +func (c *Conn) handleRenegotiation() error { + if c.vers == VersionTLS13 { + return errors.New("tls: internal error: unexpected renegotiation") + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + helloReq, ok := msg.(*helloRequestMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(helloReq, msg) + } + + if !c.isClient { + return c.sendAlert(alertNoRenegotiation) + } + + switch c.config.Renegotiation { + case RenegotiateNever: + return c.sendAlert(alertNoRenegotiation) + case RenegotiateOnceAsClient: + if c.handshakes > 1 { + return c.sendAlert(alertNoRenegotiation) + } + case RenegotiateFreelyAsClient: + // Ok. + default: + c.sendAlert(alertInternalError) + return errors.New("tls: unknown Renegotiation value") + } + + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + c.handshakeStatus = 0 + if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil { + c.handshakes++ + } + return c.handshakeErr +} + +// handlePostHandshakeMessage processes a handshake message arrived after the +// handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation. +func (c *Conn) handlePostHandshakeMessage() error { + if c.vers != VersionTLS13 { + return c.handleRenegotiation() + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + c.retryCount++ + if c.retryCount > maxUselessRecords { + c.sendAlert(alertUnexpectedMessage) + return c.in.setErrorLocked(errors.New("tls: too many non-advancing records")) + } + + switch msg := msg.(type) { + case *newSessionTicketMsgTLS13: + return c.handleNewSessionTicket(msg) + case *keyUpdateMsg: + return c.handleKeyUpdate(msg) + default: + c.sendAlert(alertUnexpectedMessage) + return fmt.Errorf("tls: received unexpected handshake message of type %T", msg) + } +} + +func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { + cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) + if cipherSuite == nil { + return c.in.setErrorLocked(c.sendAlert(alertInternalError)) + } + + newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret) + c.in.setTrafficSecret(cipherSuite, newSecret) + + if keyUpdate.updateRequested { + + msg := &keyUpdateMsg{} + _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal()) + if err != nil { + // Surface the error at the next write. + c.out.setErrorLocked(err) + return nil + } + + newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret) + c.out.setTrafficSecret(cipherSuite, newSecret) + } + + return nil +} +func (c *Conn) connectionStateLocked() ConnectionState { + var state ConnectionState + state.HandshakeComplete = c.handshakeStatus == 255 + state.Version = c.vers + state.NegotiatedProtocol = c.clientProtocol + state.DidResume = c.didResume + state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback + state.ServerName = c.serverName + state.CipherSuite = c.cipherSuite + state.PeerCertificates = c.peerCertificates + state.VerifiedChains = c.verifiedChains + state.SignedCertificateTimestamps = c.scts + state.OCSPResponse = c.ocspResponse + if !c.didResume && c.vers != VersionTLS13 { + if c.clientFinishedIsFirst { + state.TLSUnique = c.clientFinished[:] + } else { + state.TLSUnique = c.serverFinished[:] + } + } + if c.config.Renegotiation != RenegotiateNever { + state.ekm = noExportedKeyingMaterial + } else { + state.ekm = c.ekm + } + return state +} +func (c *Conn) HandshakeComplete() bool { + return c.handshakeStatus == 255 +} diff --git a/pkg/tls/handshake_client.go b/pkg/tls/handshake_client.go new file mode 100644 index 000000000..f3f3d043b --- /dev/null +++ b/pkg/tls/handshake_client.go @@ -0,0 +1,1034 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/subtle" + "crypto/x509" + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" + "time" +) + +type clientHandshakeState struct { + c *Conn + serverHello *serverHelloMsg + hello *clientHelloMsg + suite *cipherSuite + finishedHash finishedHash + masterSecret []byte + session *ClientSessionState + oldsession *ClientSessionState + cacheKey string +} + +func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) { + config := c.config + if len(config.ServerName) == 0 && !config.InsecureSkipVerify { + return nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") + } + + nextProtosLength := 0 + for _, proto := range config.NextProtos { + if l := len(proto); l == 0 || l > 255 { + return nil, nil, errors.New("tls: invalid NextProtos value") + } else { + nextProtosLength += 1 + l + } + } + if nextProtosLength > 0xffff { + return nil, nil, errors.New("tls: NextProtos values too large") + } + + supportedVersions := config.supportedVersions() + if len(supportedVersions) == 0 { + return nil, nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion") + } + + clientHelloVersion := config.maxSupportedVersion() + // The version at the beginning of the ClientHello was capped at TLS 1.2 + // for compatibility reasons. The supported_versions extension is used + // to negotiate versions now. See RFC 8446, Section 4.2.1. + if clientHelloVersion > VersionTLS12 { + clientHelloVersion = VersionTLS12 + } + + hello := &clientHelloMsg{ + vers: clientHelloVersion, + compressionMethods: []uint8{compressionNone}, + random: make([]byte, 32), + sessionId: make([]byte, 32), + ocspStapling: true, + scts: true, + serverName: hostnameInSNI(config.ServerName), + supportedCurves: config.curvePreferences(), + supportedPoints: []uint8{pointFormatUncompressed}, + secureRenegotiationSupported: true, + alpnProtocols: config.NextProtos, + supportedVersions: supportedVersions, + } + + if c.handshakes > 0 { + hello.secureRenegotiation = c.clientFinished[:] + } + + possibleCipherSuites := config.cipherSuites() + hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites)) + + for _, suiteId := range possibleCipherSuites { + for _, suite := range cipherSuites { + if suite.id != suiteId { + continue + } + // Don't advertise TLS 1.2-only cipher suites unless + // we're attempting TLS 1.2. + if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 { + break + } + hello.cipherSuites = append(hello.cipherSuites, suiteId) + break + } + } + + _, err := io.ReadFull(config.rand(), hello.random) + if err != nil { + return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) + } + + // A random session ID is used to detect when the server accepted a ticket + // and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as + // a compatibility measure (see RFC 8446, Section 4.1.2). + if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil { + return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) + } + + if hello.vers >= VersionTLS12 { + hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms + } + + var params ecdheParameters + if hello.supportedVersions[0] == VersionTLS13 { + hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13()...) + + curveID := config.curvePreferences()[0] + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") + } + params, err = generateECDHEParameters(config.rand(), curveID) + if err != nil { + return nil, nil, err + } + hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} + } + + return hello, params, nil +} + +func (c *Conn) clientHandshake() (err error) { + + switch c.handshakeStatus { + case 0: + if c.config == nil { + c.config = defaultConfig() + } + // This may be a renegotiation handshake, in which case some fields + // need to be reset. + c.didResume = false + + hello, ecdheParams, err := c.makeClientHello() + if err != nil { + return err + } + c.serverName = hello.serverName + + cacheKey, session, earlySecret, binderKey := c.loadSession(hello) + if cacheKey != "" && session != nil { + defer func() { + // If we got a handshake failure when resuming a session, throw away + // the session ticket. See RFC 5077, Section 3.2. + // + // RFC 8446 makes no mention of dropping tickets on failure, but it + // does require servers to abort on invalid binders, so we need to + // delete tickets to recover from a corrupted PSK. + if err != nil { + c.config.ClientSessionCache.Put(cacheKey, nil) + } + }() + } + + if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil { + return err + } + c.flush() + c.handshakeStatus = 1 //已发送hello,等待下一个数据包 + + c.hs = &clientHandshakeStateTLS13{ //临时缓存 + c: c, + hello: hello, + ecdheParams: ecdheParams, + session: session, + earlySecret: earlySecret, + binderKey: binderKey, + cacheKey: cacheKey, + } + + case 1: + hello := c.hs.(*clientHandshakeStateTLS13).hello + msg, err := c.readHandshake() + if err != nil { + return err + } + + serverHello, ok := msg.(*serverHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverHello, msg) + } + + if err := c.pickTLSVersion(serverHello); err != nil { + return err + } + c.handshakeStatus = 2 + // If we are negotiating a protocol version that's lower than what we + // support, check for the server downgrade canaries. + // See RFC 8446, Section 4.1.3. + maxVers := c.config.maxSupportedVersion() + tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12 + tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11 + if maxVers == VersionTLS13 && c.vers <= VersionTLS12 && (tls12Downgrade || tls11Downgrade) || + maxVers == VersionTLS12 && c.vers <= VersionTLS11 && tls11Downgrade { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox") + } + if c.vers == VersionTLS13 { + c.hs.(*clientHandshakeStateTLS13).serverHello = serverHello + // In TLS 1.3, session tickets are delivered after the handshake. + return c.hs.handshake() + } + hs := &clientHandshakeState{ + c: c, + serverHello: serverHello, + hello: hello, + session: c.hs.(*clientHandshakeStateTLS13).session, + oldsession: c.hs.(*clientHandshakeStateTLS13).session, + cacheKey: c.hs.(*clientHandshakeStateTLS13).cacheKey, + } + c.hs = hs + if err := hs.handshake(); err != nil { + return err + } + case 3, 4, 5: + c.hs.handshake() + default: + return errors.New("tls handshakeStatus error:" + strconv.Itoa(int(c.handshakeStatus))) + } + return nil +} + +func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, + session *ClientSessionState, earlySecret, binderKey []byte) { + if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { + return "", nil, nil, nil + } + + hello.ticketSupported = true + + if hello.supportedVersions[0] == VersionTLS13 { + // Require DHE on resumption as it guarantees forward secrecy against + // compromise of the session ticket key. See RFC 8446, Section 4.2.9. + hello.pskModes = []uint8{pskModeDHE} + } + + // Session resumption is not allowed if renegotiating because + // renegotiation is primarily used to allow a client to send a client + // certificate, which would be skipped if session resumption occurred. + if c.handshakes != 0 { + return "", nil, nil, nil + } + + // Try to resume a previously negotiated TLS session, if available. + cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) + session, ok := c.config.ClientSessionCache.Get(cacheKey) + if !ok || session == nil { + return cacheKey, nil, nil, nil + } + + // Check that version used for the previous session is still valid. + versOk := false + + for _, v := range hello.supportedVersions { + if v == session.vers { + versOk = true + break + } + } + if !versOk { + return cacheKey, nil, nil, nil + } + + // Check that the cached server certificate is not expired, and that it's + // valid for the ServerName. This should be ensured by the cache key, but + // protect the application from a faulty ClientSessionCache implementation. + if !c.config.InsecureSkipVerify { + if len(session.verifiedChains) == 0 { + // The original connection had InsecureSkipVerify, while this doesn't. + return cacheKey, nil, nil, nil + } + serverCert := session.serverCertificates[0] + if c.config.time().After(serverCert.NotAfter) { + // Expired certificate, delete the entry. + c.config.ClientSessionCache.Put(cacheKey, nil) + return cacheKey, nil, nil, nil + } + if err := serverCert.VerifyHostname(c.config.ServerName); err != nil { + return cacheKey, nil, nil, nil + } + } + + if session.vers != VersionTLS13 { + // In TLS 1.2 the cipher suite must match the resumed session. Ensure we + // are still offering it. + if mutualCipherSuite(hello.cipherSuites, session.cipherSuite) == nil { + return cacheKey, nil, nil, nil + } + + hello.sessionTicket = session.sessionTicket + return + } + + // Check that the session ticket is not expired. + if c.config.time().After(session.useBy) { + c.config.ClientSessionCache.Put(cacheKey, nil) + return cacheKey, nil, nil, nil + } + + // In TLS 1.3 the KDF hash must match the resumed session. Ensure we + // offer at least one cipher suite with that hash. + cipherSuite := cipherSuiteTLS13ByID(session.cipherSuite) + if cipherSuite == nil { + return cacheKey, nil, nil, nil + } + cipherSuiteOk := false + for _, offeredID := range hello.cipherSuites { + offeredSuite := cipherSuiteTLS13ByID(offeredID) + if offeredSuite != nil && offeredSuite.hash == cipherSuite.hash { + cipherSuiteOk = true + break + } + } + if !cipherSuiteOk { + return cacheKey, nil, nil, nil + } + + // Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1. + ticketAge := uint32(c.config.time().Sub(session.receivedAt) / time.Millisecond) + identity := pskIdentity{ + label: session.sessionTicket, + obfuscatedTicketAge: ticketAge + session.ageAdd, + } + hello.pskIdentities = []pskIdentity{identity} + hello.pskBinders = [][]byte{make([]byte, cipherSuite.hash.Size())} + + // Compute the PSK binders. See RFC 8446, Section 4.2.11.2. + psk := cipherSuite.expandLabel(session.masterSecret, "resumption", + session.nonce, cipherSuite.hash.Size()) + earlySecret = cipherSuite.extract(psk, nil) + binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil) + transcript := cipherSuite.hash.New() + transcript.Write(hello.marshalWithoutBinders()) + pskBinders := [][]byte{cipherSuite.finishedHash(binderKey, transcript)} + hello.updateBinders(pskBinders) + + return +} + +func (c *Conn) pickTLSVersion(serverHello *serverHelloMsg) error { + peerVersion := serverHello.vers + if serverHello.supportedVersion != 0 { + peerVersion = serverHello.supportedVersion + } + + vers, ok := c.config.mutualVersion([]uint16{peerVersion}) + if !ok { + c.sendAlert(alertProtocolVersion) + return fmt.Errorf("tls: server selected unsupported protocol version %x", peerVersion) + } + + c.vers = vers + c.haveVers = true + c.in.version = vers + c.out.version = vers + + return nil +} + +// Does the handshake, either a full one or resumes old session. Requires hs.c, +// hs.hello, hs.serverHello, and, optionally, hs.session to be set. +func (hs *clientHandshakeState) handshake() (err error) { + c := hs.c + + if c.handshakeStatus == 2 { + c.didResume, err = hs.processServerHello() + if err != nil { + return err + } + hs.finishedHash = newFinishedHash(c.vers, hs.suite) + + // No signatures of the handshake are needed in a resumption. + // Otherwise, in a full handshake, if we don't have any certificates + // configured then we will never send a CertificateVerify message and + // thus no signatures are needed in that case either. + if c.didResume || (len(c.config.Certificates) == 0 && c.config.GetClientCertificate == nil) { + hs.finishedHash.discardHandshakeBuffer() + } + + hs.finishedHash.Write(hs.hello.marshal()) + hs.finishedHash.Write(hs.serverHello.marshal()) + c.handshakeStatus = 3 + //c.buffering = true + } + + if c.didResume { + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.readSessionTicket(); err != nil { + return err + } + if err := hs.readFinished(c.serverFinished[:]); err != nil { + return err + } + c.clientFinishedIsFirst = false + // Make sure the connection is still being verified whether or not this + // is a resumption. Resumptions currently don't reverify certificates so + // they don't call verifyServerCertificate. See Issue 31641. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + if err := hs.sendFinished(c.clientFinished[:]); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + } else { + switch c.handshakeStatus { + case 3: + if err := hs.doFullHandshakeStep1(); err != nil { + return err + } + c.handshakeStatus = 4 + return nil + case 4: + if err := hs.doFullHandshakeStep2(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.sendFinished(c.clientFinished[:]); err != nil { + return err + } + _, err := c.flush() + + c.handshakeStatus = 5 + return err + case 5: + c.clientFinishedIsFirst = true + if err := hs.readSessionTicket(); err != nil { + return err + } + if err := hs.readFinished(c.serverFinished[:]); err != nil { + return err + } + } + } + + c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random) + + c.handshakeStatus = 255 + // If we had a successful handshake and hs.session is different from + // the one already cached - cache a new one. + + if hs.cacheKey != "" && hs.session != nil && hs.oldsession != hs.session { + c.config.ClientSessionCache.Put(hs.cacheKey, hs.session) + } + return nil +} + +func (hs *clientHandshakeState) pickCipherSuite() error { + if hs.suite = mutualCipherSuite(hs.hello.cipherSuites, hs.serverHello.cipherSuite); hs.suite == nil { + hs.c.sendAlert(alertHandshakeFailure) + return errors.New("tls: server chose an unconfigured cipher suite") + } + + hs.c.cipherSuite = hs.suite.id + return nil +} + +func (hs *clientHandshakeState) doFullHandshakeStep1() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + certMsg, ok := msg.(*certificateMsg) + if !ok || len(certMsg.certificates) == 0 { + c.sendAlert(alertUnexpectedMessage) + + return unexpectedMessageError(certMsg, msg) + } + hs.finishedHash.Write(certMsg.marshal()) + if c.handshakes == 1 || len(c.peerCertificates) == 0 { + // If this is the first handshake on a connection, process and + // (optionally) verify the server's certificates. + if err := c.verifyServerCertificate(certMsg.certificates); err != nil { + return err + } + } else { + // This is a renegotiation handshake. We require that the + // server's identity (i.e. leaf certificate) is unchanged and + // thus any previous trust decision is still valid. + // + // See https://mitls.org/pages/attacks/3SHAKE for the + // motivation behind this requirement. + if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) { + c.sendAlert(alertBadCertificate) + return errors.New("tls: server's identity changed during renegotiation") + } + } + return nil +} +func (hs *clientHandshakeState) doFullHandshakeStep2() error { + c := hs.c + msg, err := c.readHandshake() + if err != nil { + return err + } + + cs, ok := msg.(*certificateStatusMsg) + if ok { + // RFC4366 on Certificate Status Request: + // The server MAY return a "certificate_status" message. + + if !hs.serverHello.ocspStapling { + // If a server returns a "CertificateStatus" message, then the + // server MUST have included an extension of type "status_request" + // with empty "extension_data" in the extended server hello. + + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: received unexpected CertificateStatus message") + } + hs.finishedHash.Write(cs.marshal()) + + c.ocspResponse = cs.response + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + keyAgreement := hs.suite.ka(c.vers) + + skx, ok := msg.(*serverKeyExchangeMsg) + if ok { + hs.finishedHash.Write(skx.marshal()) + err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx) + if err != nil { + c.sendAlert(alertUnexpectedMessage) + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + var chainToSend *Certificate + var certRequested bool + certReq, ok := msg.(*certificateRequestMsg) + if ok { + certRequested = true + hs.finishedHash.Write(certReq.marshal()) + + cri := certificateRequestInfoFromMsg(c.vers, certReq) + if chainToSend, err = c.getClientCertificate(cri); err != nil { + c.sendAlert(alertInternalError) + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + shd, ok := msg.(*serverHelloDoneMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(shd, msg) + } + hs.finishedHash.Write(shd.marshal()) + + // If the server requested a certificate then we have to send a + // Certificate message, even if it's empty because we don't have a + // certificate to send. + if certRequested { + certMsg := new(certificateMsg) + certMsg.certificates = chainToSend.Certificate + hs.finishedHash.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + } + + preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, c.peerCertificates[0]) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + if ckx != nil { + hs.finishedHash.Write(ckx.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil { + return err + } + } + + if chainToSend != nil && len(chainToSend.Certificate) > 0 { + certVerify := &certificateVerifyMsg{} + + key, ok := chainToSend.PrivateKey.(crypto.Signer) + if !ok { + c.sendAlert(alertInternalError) + return fmt.Errorf("tls: client certificate private key of type %T does not implement crypto.Signer", chainToSend.PrivateKey) + } + + var sigType uint8 + var sigHash crypto.Hash + if c.vers >= VersionTLS12 { + signatureAlgorithm, err := selectSignatureScheme(c.vers, chainToSend, certReq.supportedSignatureAlgorithms) + if err != nil { + c.sendAlert(alertIllegalParameter) + return err + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + certVerify.hasSignatureAlgorithm = true + certVerify.signatureAlgorithm = signatureAlgorithm + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(key.Public()) + if err != nil { + c.sendAlert(alertIllegalParameter) + return err + } + } + + signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret) + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + certVerify.signature, err = key.Sign(c.config.rand(), signed, signOpts) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + hs.finishedHash.Write(certVerify.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil { + return err + } + } + + hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random) + if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.hello.random, hs.masterSecret); err != nil { + c.sendAlert(alertInternalError) + return errors.New("tls: failed to write to key log: " + err.Error()) + } + + hs.finishedHash.discardHandshakeBuffer() + + return nil +} + +func (hs *clientHandshakeState) establishKeys() error { + c := hs.c + + clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := + keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) + var clientCipher, serverCipher interface{} + var clientHash, serverHash macFunction + if hs.suite.cipher != nil { + clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */) + clientHash = hs.suite.mac(c.vers, clientMAC) + serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */) + serverHash = hs.suite.mac(c.vers, serverMAC) + } else { + clientCipher = hs.suite.aead(clientKey, clientIV) + serverCipher = hs.suite.aead(serverKey, serverIV) + } + + c.in.prepareCipherSpec(c.vers, serverCipher, serverHash) + c.out.prepareCipherSpec(c.vers, clientCipher, clientHash) + return nil +} + +func (hs *clientHandshakeState) serverResumedSession() bool { + // If the server responded with the same sessionId then it means the + // sessionTicket is being used to resume a TLS session. + return hs.session != nil && hs.hello.sessionId != nil && + bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId) +} + +func (hs *clientHandshakeState) processServerHello() (bool, error) { + c := hs.c + + if err := hs.pickCipherSuite(); err != nil { + return false, err + } + + if hs.serverHello.compressionMethod != compressionNone { + c.sendAlert(alertUnexpectedMessage) + return false, errors.New("tls: server selected unsupported compression format") + } + + if c.handshakes == 0 && hs.serverHello.secureRenegotiationSupported { + c.secureRenegotiation = true + if len(hs.serverHello.secureRenegotiation) != 0 { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: initial handshake had non-empty renegotiation extension") + } + } + + if c.handshakes > 0 && c.secureRenegotiation { + var expectedSecureRenegotiation [24]byte + copy(expectedSecureRenegotiation[:], c.clientFinished[:]) + copy(expectedSecureRenegotiation[12:], c.serverFinished[:]) + if !bytes.Equal(hs.serverHello.secureRenegotiation, expectedSecureRenegotiation[:]) { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: incorrect renegotiation extension contents") + } + } + + clientDidALPN := len(hs.hello.alpnProtocols) > 0 + serverHasALPN := len(hs.serverHello.alpnProtocol) > 0 + + if !clientDidALPN && serverHasALPN { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: server advertised unrequested ALPN extension") + } + + if serverHasALPN { + c.clientProtocol = hs.serverHello.alpnProtocol + c.clientProtocolFallback = false + } + c.scts = hs.serverHello.scts + + if !hs.serverResumedSession() { + return false, nil + } + + if hs.session.vers != c.vers { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: server resumed a session with a different version") + } + + if hs.session.cipherSuite != hs.suite.id { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: server resumed a session with a different cipher suite") + } + + // Restore masterSecret, peerCerts, and ocspResponse from previous state + hs.masterSecret = hs.session.masterSecret + c.peerCertificates = hs.session.serverCertificates + c.verifiedChains = hs.session.verifiedChains + c.ocspResponse = hs.session.ocspResponse + // Let the ServerHello SCTs override the session SCTs from the original + // connection, if any are provided + if len(c.scts) == 0 && len(hs.session.scts) != 0 { + c.scts = hs.session.scts + } + return true, nil +} + +func (hs *clientHandshakeState) readFinished(out []byte) error { + c := hs.c + + if err := c.readChangeCipherSpec(); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + serverFinished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverFinished, msg) + } + + verify := hs.finishedHash.serverSum(hs.masterSecret) + if len(verify) != len(serverFinished.verifyData) || + subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: server's Finished message was incorrect") + } + hs.finishedHash.Write(serverFinished.marshal()) + copy(out, verify) + return nil +} + +func (hs *clientHandshakeState) readSessionTicket() error { + if !hs.serverHello.ticketSupported { + return nil + } + + c := hs.c + msg, err := c.readHandshake() + if err != nil { + return err + } + sessionTicketMsg, ok := msg.(*newSessionTicketMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(sessionTicketMsg, msg) + } + hs.finishedHash.Write(sessionTicketMsg.marshal()) + + hs.session = &ClientSessionState{ + sessionTicket: sessionTicketMsg.ticket, + vers: c.vers, + cipherSuite: hs.suite.id, + masterSecret: hs.masterSecret, + serverCertificates: c.peerCertificates, + verifiedChains: c.verifiedChains, + receivedAt: c.config.time(), + ocspResponse: c.ocspResponse, + scts: c.scts, + } + + return nil +} + +func (hs *clientHandshakeState) sendFinished(out []byte) error { + c := hs.c + + if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + return err + } + + finished := new(finishedMsg) + finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) + hs.finishedHash.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + copy(out, finished.verifyData) + return nil +} + +// verifyServerCertificate parses and verifies the provided chain, setting +// c.verifiedChains and c.peerCertificates or sending the appropriate alert. +func (c *Conn) verifyServerCertificate(certificates [][]byte) error { + certs := make([]*x509.Certificate, len(certificates)) + for i, asn1Data := range certificates { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("tls: failed to parse certificate from server: " + err.Error()) + } + certs[i] = cert + } + + if !c.config.InsecureSkipVerify { + opts := x509.VerifyOptions{ + Roots: c.config.RootCAs, + CurrentTime: c.config.time(), + DNSName: c.config.ServerName, + Intermediates: x509.NewCertPool(), + } + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + var err error + c.verifiedChains, err = certs[0].Verify(opts) + if err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + + + switch certs[0].PublicKey.(type) { + case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey: + break + default: + c.sendAlert(alertUnsupportedCertificate) + return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey) + } + + c.peerCertificates = certs + + if c.config.VerifyPeerCertificate != nil { + if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + return nil +} + +// certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS +// <= 1.2 CertificateRequest, making an effort to fill in missing information. +func certificateRequestInfoFromMsg(vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { + cri := &CertificateRequestInfo{ + AcceptableCAs: certReq.certificateAuthorities, + Version: vers, + } + + var rsaAvail, ecAvail bool + for _, certType := range certReq.certificateTypes { + switch certType { + case certTypeRSASign: + rsaAvail = true + case certTypeECDSASign: + ecAvail = true + } + } + + if !certReq.hasSignatureAlgorithm { + // Prior to TLS 1.2, signature schemes did not exist. In this case we + // make up a list based on the acceptable certificate types, to help + // GetClientCertificate and SupportsCertificate select the right certificate. + // The hash part of the SignatureScheme is a lie here, because + // TLS 1.0 and 1.1 always use MD5+SHA1 for RSA and SHA1 for ECDSA. + switch { + case rsaAvail && ecAvail: + cri.SignatureSchemes = []SignatureScheme{ + ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, + PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1, + } + case rsaAvail: + cri.SignatureSchemes = []SignatureScheme{ + PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1, + } + case ecAvail: + cri.SignatureSchemes = []SignatureScheme{ + ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, + } + } + return cri + } + + // Filter the signature schemes based on the certificate types. + // See RFC 5246, Section 7.4.4 (where it calls this "somewhat complicated"). + cri.SignatureSchemes = make([]SignatureScheme, 0, len(certReq.supportedSignatureAlgorithms)) + for _, sigScheme := range certReq.supportedSignatureAlgorithms { + sigType, _, err := typeAndHashFromSignatureScheme(sigScheme) + if err != nil { + continue + } + switch sigType { + case signatureECDSA, signatureEd25519: + if ecAvail { + cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme) + } + case signatureRSAPSS, signaturePKCS1v15: + if rsaAvail { + cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme) + } + } + } + + return cri +} + +func (c *Conn) getClientCertificate(cri *CertificateRequestInfo) (*Certificate, error) { + if c.config.GetClientCertificate != nil { + return c.config.GetClientCertificate(cri) + } + + for _, chain := range c.config.Certificates { + if err := cri.SupportsCertificate(&chain); err != nil { + continue + } + return &chain, nil + } + + // No acceptable certificate found. Don't send a certificate. + return new(Certificate), nil +} + +// clientSessionCacheKey returns a key used to cache sessionTickets that could +// be used to resume previously negotiated TLS sessions with a server. +func clientSessionCacheKey(serverAddr net.Addr, config *Config) string { + if len(config.ServerName) > 0 { + return config.ServerName + } + return serverAddr.String() +} + +// mutualProtocol finds the mutual Next Protocol Negotiation or ALPN protocol +// given list of possible protocols and a list of the preference order. The +// first list must not be empty. It returns the resulting protocol and flag +// indicating if the fallback case was reached. +func mutualProtocol(protos, preferenceProtos []string) (string, bool) { + for _, s := range preferenceProtos { + for _, c := range protos { + if s == c { + return s, false + } + } + } + + return protos[0], true +} + +// hostnameInSNI converts name into an appropriate hostname for SNI. +// Literal IP addresses and absolute FQDNs are not permitted as SNI values. +// See RFC 6066, Section 3. +func hostnameInSNI(name string) string { + host := name + if len(host) > 0 && host[0] == '[' && host[len(host)-1] == ']' { + host = host[1 : len(host)-1] + } + if i := strings.LastIndex(host, "%"); i > 0 { + host = host[:i] + } + if net.ParseIP(host) != nil { + return "" + } + for len(name) > 0 && name[len(name)-1] == '.' { + name = name[:len(name)-1] + } + return name +} diff --git a/pkg/tls/handshake_client_tls13.go b/pkg/tls/handshake_client_tls13.go new file mode 100644 index 000000000..5cf35aec6 --- /dev/null +++ b/pkg/tls/handshake_client_tls13.go @@ -0,0 +1,675 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "crypto" + "crypto/hmac" + "crypto/rsa" + "errors" + "hash" + "time" +) + +type clientHandshakeStateTLS13 struct { + c *Conn + serverHello *serverHelloMsg + hello *clientHelloMsg + ecdheParams ecdheParameters + + session *ClientSessionState + earlySecret []byte + binderKey []byte + + certReq *certificateRequestMsgTLS13 + usingPSK bool + sentDummyCCS bool + suite *cipherSuiteTLS13 + transcript hash.Hash + masterSecret []byte + trafficSecret []byte // client_application_traffic_secret_0 + cacheKey string +} + +// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheParams, and, +// optionally, hs.session, hs.earlySecret and hs.binderKey to be set. +func (hs *clientHandshakeStateTLS13) handshake() error { + c := hs.c + + // The server must not select TLS 1.3 in a renegotiation. See RFC 8446, + // sections 4.1.2 and 4.1.3. + if c.handshakes > 255 { + return errors.New("tls: server selected TLS 1.3 in a renegotiation") + } + + // Consistency check on the presence of a keyShare and its parameters. + if hs.ecdheParams == nil || len(hs.hello.keyShares) != 1 { + return c.sendAlert(alertInternalError) + } + + if err := hs.checkServerHelloOrHRR(); err != nil { + return err + } + + hs.transcript = hs.suite.hash.New() + hs.transcript.Write(hs.hello.marshal()) + + if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + if err := hs.processHelloRetryRequest(); err != nil { + return err + } + } + hs.transcript.Write(hs.serverHello.marshal()) + + if err := hs.processServerHello(); err != nil { + return err + } + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + if err := hs.establishHandshakeKeys(); err != nil { + return err + } + if err := hs.readServerParameters(); err != nil { + return err + } + if err := hs.readServerCertificate(); err != nil { + return err + } + if err := hs.readServerFinished(); err != nil { + return err + } + if err := hs.sendClientCertificate(); err != nil { + return err + } + if err := hs.sendClientFinished(); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + + c.handshakeStatus = 255 + + return nil +} + +// checkServerHelloOrHRR does validity checks that apply to both ServerHello and +// HelloRetryRequest messages. It sets hs.suite. +func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error { + c := hs.c + + if hs.serverHello.supportedVersion == 0 { + c.sendAlert(alertMissingExtension) + return errors.New("tls: server selected TLS 1.3 using the legacy version field") + } + + if hs.serverHello.supportedVersion != VersionTLS13 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected an invalid version after a HelloRetryRequest") + } + + if hs.serverHello.vers != VersionTLS12 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server sent an incorrect legacy version") + } + + if hs.serverHello.ocspStapling || + hs.serverHello.ticketSupported || + hs.serverHello.secureRenegotiationSupported || + len(hs.serverHello.secureRenegotiation) != 0 || + len(hs.serverHello.alpnProtocol) != 0 || + len(hs.serverHello.scts) != 0 { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: server sent a ServerHello extension forbidden in TLS 1.3") + } + + if !bytes.Equal(hs.hello.sessionId, hs.serverHello.sessionId) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server did not echo the legacy session ID") + } + + if hs.serverHello.compressionMethod != compressionNone { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected unsupported compression format") + } + + selectedSuite := mutualCipherSuiteTLS13(hs.hello.cipherSuites, hs.serverHello.cipherSuite) + if hs.suite != nil && selectedSuite != hs.suite { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server changed cipher suite after a HelloRetryRequest") + } + if selectedSuite == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server chose an unconfigured cipher suite") + } + hs.suite = selectedSuite + c.cipherSuite = hs.suite.id + + return nil +} + +// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility +// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. +func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error { + if hs.sentDummyCCS { + return nil + } + hs.sentDummyCCS = true + + _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) + return err +} + +// processHelloRetryRequest handles the HRR in hs.serverHello, modifies and +// resends hs.hello, and reads the new ServerHello into hs.serverHello. +func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { + c := hs.c + + // The first ClientHello gets double-hashed into the transcript upon a + // HelloRetryRequest. (The idea is that the server might offload transcript + // storage to the client in the cookie.) See RFC 8446, Section 4.4.1. + chHash := hs.transcript.Sum(nil) + hs.transcript.Reset() + hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) + hs.transcript.Write(chHash) + hs.transcript.Write(hs.serverHello.marshal()) + + // The only HelloRetryRequest extensions we support are key_share and + // cookie, and clients must abort the handshake if the HRR would not result + // in any change in the ClientHello. + if hs.serverHello.selectedGroup == 0 && hs.serverHello.cookie == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server sent an unnecessary HelloRetryRequest message") + } + + if hs.serverHello.cookie != nil { + hs.hello.cookie = hs.serverHello.cookie + } + if hs.serverHello.serverShare.group != 0 { + c.sendAlert(alertDecodeError) + return errors.New("tls: received malformed key_share extension") + } + + // If the server sent a key_share extension selecting a group, ensure it's + // a group we advertised but did not send a key share for, and send a key + // share for it this time. + if curveID := hs.serverHello.selectedGroup; curveID != 0 { + curveOK := false + for _, id := range hs.hello.supportedCurves { + if id == curveID { + curveOK = true + break + } + } + if !curveOK { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected unsupported group") + } + if hs.ecdheParams.CurveID() == curveID { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share") + } + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + c.sendAlert(alertInternalError) + return errors.New("tls: CurvePreferences includes unsupported curve") + } + params, err := generateECDHEParameters(c.config.rand(), curveID) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + hs.ecdheParams = params + hs.hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} + } + + hs.hello.raw = nil + if len(hs.hello.pskIdentities) > 0 { + pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) + if pskSuite == nil { + return c.sendAlert(alertInternalError) + } + if pskSuite.hash == hs.suite.hash { + // Update binders and obfuscated_ticket_age. + ticketAge := uint32(c.config.time().Sub(hs.session.receivedAt) / time.Millisecond) + hs.hello.pskIdentities[0].obfuscatedTicketAge = ticketAge + hs.session.ageAdd + + transcript := hs.suite.hash.New() + transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) + transcript.Write(chHash) + transcript.Write(hs.serverHello.marshal()) + transcript.Write(hs.hello.marshalWithoutBinders()) + pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)} + hs.hello.updateBinders(pskBinders) + } else { + // Server selected a cipher suite incompatible with the PSK. + hs.hello.pskIdentities = nil + hs.hello.pskBinders = nil + } + } + + hs.transcript.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + serverHello, ok := msg.(*serverHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverHello, msg) + } + hs.serverHello = serverHello + + if err := hs.checkServerHelloOrHRR(); err != nil { + return err + } + + return nil +} + +func (hs *clientHandshakeStateTLS13) processServerHello() error { + c := hs.c + + if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: server sent two HelloRetryRequest messages") + } + + if len(hs.serverHello.cookie) != 0 { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: server sent a cookie in a normal ServerHello") + } + + if hs.serverHello.selectedGroup != 0 { + c.sendAlert(alertDecodeError) + return errors.New("tls: malformed key_share extension") + } + + if hs.serverHello.serverShare.group == 0 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server did not send a key share") + } + if hs.serverHello.serverShare.group != hs.ecdheParams.CurveID() { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected unsupported group") + } + + if !hs.serverHello.selectedIdentityPresent { + return nil + } + + if int(hs.serverHello.selectedIdentity) >= len(hs.hello.pskIdentities) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected an invalid PSK") + } + + if len(hs.hello.pskIdentities) != 1 || hs.session == nil { + return c.sendAlert(alertInternalError) + } + pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) + if pskSuite == nil { + return c.sendAlert(alertInternalError) + } + if pskSuite.hash != hs.suite.hash { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected an invalid PSK and cipher suite pair") + } + + hs.usingPSK = true + c.didResume = true + c.peerCertificates = hs.session.serverCertificates + c.verifiedChains = hs.session.verifiedChains + c.ocspResponse = hs.session.ocspResponse + c.scts = hs.session.scts + return nil +} + +func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { + c := hs.c + + sharedKey := hs.ecdheParams.SharedKey(hs.serverHello.serverShare.data) + if sharedKey == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid server key share") + } + + earlySecret := hs.earlySecret + if !hs.usingPSK { + earlySecret = hs.suite.extract(nil, nil) + } + handshakeSecret := hs.suite.extract(sharedKey, + hs.suite.deriveSecret(earlySecret, "derived", nil)) + + clientSecret := hs.suite.deriveSecret(handshakeSecret, + clientHandshakeTrafficLabel, hs.transcript) + c.out.setTrafficSecret(hs.suite, clientSecret) + serverSecret := hs.suite.deriveSecret(handshakeSecret, + serverHandshakeTrafficLabel, hs.transcript) + c.in.setTrafficSecret(hs.suite, serverSecret) + + err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.hello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + hs.masterSecret = hs.suite.extract(nil, + hs.suite.deriveSecret(handshakeSecret, "derived", nil)) + + return nil +} + +func (hs *clientHandshakeStateTLS13) readServerParameters() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + + encryptedExtensions, ok := msg.(*encryptedExtensionsMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(encryptedExtensions, msg) + } + hs.transcript.Write(encryptedExtensions.marshal()) + + if len(encryptedExtensions.alpnProtocol) != 0 && len(hs.hello.alpnProtocols) == 0 { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: server advertised unrequested ALPN extension") + } + c.clientProtocol = encryptedExtensions.alpnProtocol + + return nil +} + +func (hs *clientHandshakeStateTLS13) readServerCertificate() error { + c := hs.c + + // Either a PSK or a certificate is always used, but not both. + // See RFC 8446, Section 4.1.1. + if hs.usingPSK { + // Make sure the connection is still being verified whether or not this + // is a resumption. Resumptions currently don't reverify certificates so + // they don't call verifyServerCertificate. See Issue 31641. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + return nil + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + certReq, ok := msg.(*certificateRequestMsgTLS13) + if ok { + hs.transcript.Write(certReq.marshal()) + + hs.certReq = certReq + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + certMsg, ok := msg.(*certificateMsgTLS13) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + if len(certMsg.certificate.Certificate) == 0 { + c.sendAlert(alertDecodeError) + return errors.New("tls: received empty certificates message") + } + hs.transcript.Write(certMsg.marshal()) + + c.scts = certMsg.certificate.SignedCertificateTimestamps + c.ocspResponse = certMsg.certificate.OCSPStaple + + if err := c.verifyServerCertificate(certMsg.certificate.Certificate); err != nil { + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + + certVerify, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certVerify, msg) + } + + // See RFC 8446, Section 4.4.3. + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: certificate used with invalid signature algorithm") + } + sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: certificate used with invalid signature algorithm") + } + signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) + if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, + sigHash, signed, certVerify.signature); err != nil { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid signature by the server certificate: " + err.Error()) + } + + hs.transcript.Write(certVerify.marshal()) + + return nil +} + +func (hs *clientHandshakeStateTLS13) readServerFinished() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + + finished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(finished, msg) + } + + expectedMAC := hs.suite.finishedHash(c.in.trafficSecret, hs.transcript) + if !hmac.Equal(expectedMAC, finished.verifyData) { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid server finished hash") + } + + hs.transcript.Write(finished.marshal()) + + // Derive secrets that take context through the server Finished. + + hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, + clientApplicationTrafficLabel, hs.transcript) + serverSecret := hs.suite.deriveSecret(hs.masterSecret, + serverApplicationTrafficLabel, hs.transcript) + c.in.setTrafficSecret(hs.suite, serverSecret) + + err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.hello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript) + + return nil +} + +func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { + c := hs.c + + if hs.certReq == nil { + return nil + } + + cert, err := c.getClientCertificate(&CertificateRequestInfo{ + AcceptableCAs: hs.certReq.certificateAuthorities, + SignatureSchemes: hs.certReq.supportedSignatureAlgorithms, + Version: c.vers, + }) + if err != nil { + return err + } + + certMsg := new(certificateMsgTLS13) + + certMsg.certificate = *cert + certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0 + certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0 + + hs.transcript.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + // If we sent an empty certificate message, skip the CertificateVerify. + if len(cert.Certificate) == 0 { + return nil + } + + certVerifyMsg := new(certificateVerifyMsg) + certVerifyMsg.hasSignatureAlgorithm = true + + certVerifyMsg.signatureAlgorithm, err = selectSignatureScheme(c.vers, cert, hs.certReq.supportedSignatureAlgorithms) + if err != nil { + // getClientCertificate returned a certificate incompatible with the + // CertificateRequestInfo supported signature algorithms. + c.sendAlert(alertHandshakeFailure) + return err + } + + sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerifyMsg.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + + signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + sig, err := cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) + if err != nil { + c.sendAlert(alertInternalError) + return errors.New("tls: failed to sign handshake: " + err.Error()) + } + certVerifyMsg.signature = sig + + hs.transcript.Write(certVerifyMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *clientHandshakeStateTLS13) sendClientFinished() error { + c := hs.c + + finished := &finishedMsg{ + verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), + } + + hs.transcript.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + + c.out.setTrafficSecret(hs.suite, hs.trafficSecret) + + if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil { + c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret, + resumptionLabel, hs.transcript) + } + + return nil +} + +func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error { + if !c.isClient { + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: received new session ticket from a client") + } + + if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { + return nil + } + + // See RFC 8446, Section 4.6.1. + if msg.lifetime == 0 { + return nil + } + lifetime := time.Duration(msg.lifetime) * time.Second + if lifetime > maxSessionTicketLifetime { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: received a session ticket with invalid lifetime") + } + + cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) + if cipherSuite == nil || c.resumptionSecret == nil { + return c.sendAlert(alertInternalError) + } + + // Save the resumption_master_secret and nonce instead of deriving the PSK + // to do the least amount of work on NewSessionTicket messages before we + // know if the ticket will be used. Forward secrecy of resumed connections + // is guaranteed by the requirement for pskModeDHE. + session := &ClientSessionState{ + sessionTicket: msg.label, + vers: c.vers, + cipherSuite: c.cipherSuite, + masterSecret: c.resumptionSecret, + serverCertificates: c.peerCertificates, + verifiedChains: c.verifiedChains, + receivedAt: c.config.time(), + nonce: msg.nonce, + useBy: c.config.time().Add(lifetime), + ageAdd: msg.ageAdd, + ocspResponse: c.ocspResponse, + scts: c.scts, + } + + cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config) + c.config.ClientSessionCache.Put(cacheKey, session) + + return nil +} diff --git a/pkg/tls/handshake_messages.go b/pkg/tls/handshake_messages.go new file mode 100644 index 000000000..b5f81e443 --- /dev/null +++ b/pkg/tls/handshake_messages.go @@ -0,0 +1,1809 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "fmt" + "strings" + + "golang.org/x/crypto/cryptobyte" +) + +// The marshalingFunction type is an adapter to allow the use of ordinary +// functions as cryptobyte.MarshalingValue. +type marshalingFunction func(b *cryptobyte.Builder) error + +func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error { + return f(b) +} + +// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If +// the length of the sequence is not the value specified, it produces an error. +func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) { + b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error { + if len(v) != n { + return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v)) + } + b.AddBytes(v) + return nil + })) +} + +// addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder. +func addUint64(b *cryptobyte.Builder, v uint64) { + b.AddUint32(uint32(v >> 32)) + b.AddUint32(uint32(v)) +} + +// readUint64 decodes a big-endian, 64-bit value into out and advances over it. +// It reports whether the read was successful. +func readUint64(s *cryptobyte.String, out *uint64) bool { + var hi, lo uint32 + if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) { + return false + } + *out = uint64(hi)<<32 | uint64(lo) + return true +} + +// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out)) +} + +// readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out)) +} + +// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out)) +} + +type clientHelloMsg struct { + raw []byte + vers uint16 + random []byte + sessionId []byte + cipherSuites []uint16 + compressionMethods []uint8 + serverName string + ocspStapling bool + supportedCurves []CurveID + supportedPoints []uint8 + ticketSupported bool + sessionTicket []uint8 + supportedSignatureAlgorithms []SignatureScheme + supportedSignatureAlgorithmsCert []SignatureScheme + secureRenegotiationSupported bool + secureRenegotiation []byte + alpnProtocols []string + scts bool + supportedVersions []uint16 + cookie []byte + keyShares []keyShare + earlyData bool + pskModes []uint8 + pskIdentities []pskIdentity + pskBinders [][]byte +} + +func (m *clientHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeClientHello) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.vers) + addBytesWithLength(b, m.random, 32) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.sessionId) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, suite := range m.cipherSuites { + b.AddUint16(suite) + } + }) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.compressionMethods) + }) + + // If extensions aren't present, omit them. + var extensionsPresent bool + bWithoutExtensions := *b + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if len(m.serverName) > 0 { + // RFC 6066, Section 3 + b.AddUint16(extensionServerName) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(0) // name_type = host_name + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.serverName)) + }) + }) + }) + } + if m.ocspStapling { + // RFC 4366, Section 3.6 + b.AddUint16(extensionStatusRequest) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(1) // status_type = ocsp + b.AddUint16(0) // empty responder_id_list + b.AddUint16(0) // empty request_extensions + }) + } + if len(m.supportedCurves) > 0 { + // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 + b.AddUint16(extensionSupportedCurves) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, curve := range m.supportedCurves { + b.AddUint16(uint16(curve)) + } + }) + }) + } + if len(m.supportedPoints) > 0 { + // RFC 4492, Section 5.1.2 + b.AddUint16(extensionSupportedPoints) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.supportedPoints) + }) + }) + } + if m.ticketSupported { + // RFC 5077, Section 3.2 + b.AddUint16(extensionSessionTicket) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.sessionTicket) + }) + } + if len(m.supportedSignatureAlgorithms) > 0 { + // RFC 5246, Section 7.4.1.4.1 + b.AddUint16(extensionSignatureAlgorithms) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithms { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.supportedSignatureAlgorithmsCert) > 0 { + // RFC 8446, Section 4.2.3 + b.AddUint16(extensionSignatureAlgorithmsCert) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if m.secureRenegotiationSupported { + // RFC 5746, Section 3.2 + b.AddUint16(extensionRenegotiationInfo) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocols) > 0 { + // RFC 7301, Section 3.1 + b.AddUint16(extensionALPN) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, proto := range m.alpnProtocols { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(proto)) + }) + } + }) + }) + } + if m.scts { + // RFC 6962, Section 3.3.1 + b.AddUint16(extensionSCT) + b.AddUint16(0) // empty extension_data + } + if len(m.supportedVersions) > 0 { + // RFC 8446, Section 4.2.1 + b.AddUint16(extensionSupportedVersions) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + for _, vers := range m.supportedVersions { + b.AddUint16(vers) + } + }) + }) + } + if len(m.cookie) > 0 { + // RFC 8446, Section 4.2.2 + b.AddUint16(extensionCookie) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.cookie) + }) + }) + } + if len(m.keyShares) > 0 { + // RFC 8446, Section 4.2.8 + b.AddUint16(extensionKeyShare) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, ks := range m.keyShares { + b.AddUint16(uint16(ks.group)) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(ks.data) + }) + } + }) + }) + } + if m.earlyData { + // RFC 8446, Section 4.2.10 + b.AddUint16(extensionEarlyData) + b.AddUint16(0) // empty extension_data + } + if len(m.pskModes) > 0 { + // RFC 8446, Section 4.2.9 + b.AddUint16(extensionPSKModes) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.pskModes) + }) + }) + } + if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension + // RFC 8446, Section 4.2.11 + b.AddUint16(extensionPreSharedKey) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, psk := range m.pskIdentities { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(psk.label) + }) + b.AddUint32(psk.obfuscatedTicketAge) + } + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, binder := range m.pskBinders { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(binder) + }) + } + }) + }) + } + + extensionsPresent = len(b.BytesOrPanic()) > 2 + }) + + if !extensionsPresent { + *b = bWithoutExtensions + } + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +// marshalWithoutBinders returns the ClientHello through the +// PreSharedKeyExtension.identities field, according to RFC 8446, Section +// 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length. +func (m *clientHelloMsg) marshalWithoutBinders() []byte { + bindersLen := 2 // uint16 length prefix + for _, binder := range m.pskBinders { + bindersLen += 1 // uint8 length prefix + bindersLen += len(binder) + } + + fullMessage := m.marshal() + return fullMessage[:len(fullMessage)-bindersLen] +} + +// updateBinders updates the m.pskBinders field, if necessary updating the +// cached marshaled representation. The supplied binders must have the same +// length as the current m.pskBinders. +func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) { + if len(pskBinders) != len(m.pskBinders) { + panic("tls: internal error: pskBinders length mismatch") + } + for i := range m.pskBinders { + if len(pskBinders[i]) != len(m.pskBinders[i]) { + panic("tls: internal error: pskBinders length mismatch") + } + } + m.pskBinders = pskBinders + if m.raw != nil { + lenWithoutBinders := len(m.marshalWithoutBinders()) + // TODO(filippo): replace with NewFixedBuilder once CL 148882 is imported. + b := cryptobyte.NewBuilder(m.raw[:lenWithoutBinders]) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, binder := range m.pskBinders { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(binder) + }) + } + }) + if len(b.BytesOrPanic()) != len(m.raw) { + panic("tls: internal error: failed to update binders") + } + } +} + +func (m *clientHelloMsg) unmarshal(data []byte) bool { + *m = clientHelloMsg{raw: data} + s := cryptobyte.String(data) + + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || + !readUint8LengthPrefixed(&s, &m.sessionId) { + return false + } + + var cipherSuites cryptobyte.String + if !s.ReadUint16LengthPrefixed(&cipherSuites) { + return false + } + m.cipherSuites = []uint16{} + m.secureRenegotiationSupported = false + for !cipherSuites.Empty() { + var suite uint16 + if !cipherSuites.ReadUint16(&suite) { + return false + } + if suite == scsvRenegotiation { + m.secureRenegotiationSupported = true + } + m.cipherSuites = append(m.cipherSuites, suite) + } + + if !readUint8LengthPrefixed(&s, &m.compressionMethods) { + return false + } + + if s.Empty() { + // ClientHello is optionally followed by extension data + return true + } + + var extensions cryptobyte.String + if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionServerName: + // RFC 6066, Section 3 + var nameList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() { + return false + } + for !nameList.Empty() { + var nameType uint8 + var serverName cryptobyte.String + if !nameList.ReadUint8(&nameType) || + !nameList.ReadUint16LengthPrefixed(&serverName) || + serverName.Empty() { + return false + } + if nameType != 0 { + continue + } + if len(m.serverName) != 0 { + // Multiple names of the same name_type are prohibited. + return false + } + m.serverName = string(serverName) + // An SNI value may not include a trailing dot. + if strings.HasSuffix(m.serverName, ".") { + return false + } + } + case extensionStatusRequest: + // RFC 4366, Section 3.6 + var statusType uint8 + var ignored cryptobyte.String + if !extData.ReadUint8(&statusType) || + !extData.ReadUint16LengthPrefixed(&ignored) || + !extData.ReadUint16LengthPrefixed(&ignored) { + return false + } + m.ocspStapling = statusType == statusTypeOCSP + case extensionSupportedCurves: + // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 + var curves cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() { + return false + } + for !curves.Empty() { + var curve uint16 + if !curves.ReadUint16(&curve) { + return false + } + m.supportedCurves = append(m.supportedCurves, CurveID(curve)) + } + case extensionSupportedPoints: + // RFC 4492, Section 5.1.2 + if !readUint8LengthPrefixed(&extData, &m.supportedPoints) || + len(m.supportedPoints) == 0 { + return false + } + case extensionSessionTicket: + // RFC 5077, Section 3.2 + m.ticketSupported = true + extData.ReadBytes(&m.sessionTicket, len(extData)) + case extensionSignatureAlgorithms: + // RFC 5246, Section 7.4.1.4.1 + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithms = append( + m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) + } + case extensionSignatureAlgorithmsCert: + // RFC 8446, Section 4.2.3 + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithmsCert = append( + m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) + } + case extensionRenegotiationInfo: + // RFC 5746, Section 3.2 + if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { + return false + } + m.secureRenegotiationSupported = true + case extensionALPN: + // RFC 7301, Section 3.1 + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { + return false + } + for !protoList.Empty() { + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() { + return false + } + m.alpnProtocols = append(m.alpnProtocols, string(proto)) + } + case extensionSCT: + // RFC 6962, Section 3.3.1 + m.scts = true + case extensionSupportedVersions: + // RFC 8446, Section 4.2.1 + var versList cryptobyte.String + if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() { + return false + } + for !versList.Empty() { + var vers uint16 + if !versList.ReadUint16(&vers) { + return false + } + m.supportedVersions = append(m.supportedVersions, vers) + } + case extensionCookie: + // RFC 8446, Section 4.2.2 + if !readUint16LengthPrefixed(&extData, &m.cookie) || + len(m.cookie) == 0 { + return false + } + case extensionKeyShare: + // RFC 8446, Section 4.2.8 + var clientShares cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&clientShares) { + return false + } + for !clientShares.Empty() { + var ks keyShare + if !clientShares.ReadUint16((*uint16)(&ks.group)) || + !readUint16LengthPrefixed(&clientShares, &ks.data) || + len(ks.data) == 0 { + return false + } + m.keyShares = append(m.keyShares, ks) + } + case extensionEarlyData: + // RFC 8446, Section 4.2.10 + m.earlyData = true + case extensionPSKModes: + // RFC 8446, Section 4.2.9 + if !readUint8LengthPrefixed(&extData, &m.pskModes) { + return false + } + case extensionPreSharedKey: + // RFC 8446, Section 4.2.11 + if !extensions.Empty() { + return false // pre_shared_key must be the last extension + } + var identities cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() { + return false + } + for !identities.Empty() { + var psk pskIdentity + if !readUint16LengthPrefixed(&identities, &psk.label) || + !identities.ReadUint32(&psk.obfuscatedTicketAge) || + len(psk.label) == 0 { + return false + } + m.pskIdentities = append(m.pskIdentities, psk) + } + var binders cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() { + return false + } + for !binders.Empty() { + var binder []byte + if !readUint8LengthPrefixed(&binders, &binder) || + len(binder) == 0 { + return false + } + m.pskBinders = append(m.pskBinders, binder) + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type serverHelloMsg struct { + raw []byte + vers uint16 + random []byte + sessionId []byte + cipherSuite uint16 + compressionMethod uint8 + ocspStapling bool + ticketSupported bool + secureRenegotiationSupported bool + secureRenegotiation []byte + alpnProtocol string + scts [][]byte + supportedVersion uint16 + serverShare keyShare + selectedIdentityPresent bool + selectedIdentity uint16 + supportedPoints []uint8 + + // HelloRetryRequest extensions + cookie []byte + selectedGroup CurveID +} + +func (m *serverHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeServerHello) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.vers) + addBytesWithLength(b, m.random, 32) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.sessionId) + }) + b.AddUint16(m.cipherSuite) + b.AddUint8(m.compressionMethod) + + // If extensions aren't present, omit them. + var extensionsPresent bool + bWithoutExtensions := *b + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.ocspStapling { + b.AddUint16(extensionStatusRequest) + b.AddUint16(0) // empty extension_data + } + if m.ticketSupported { + b.AddUint16(extensionSessionTicket) + b.AddUint16(0) // empty extension_data + } + if m.secureRenegotiationSupported { + b.AddUint16(extensionRenegotiationInfo) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocol) > 0 { + b.AddUint16(extensionALPN) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.alpnProtocol)) + }) + }) + }) + } + if len(m.scts) > 0 { + b.AddUint16(extensionSCT) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sct := range m.scts { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(sct) + }) + } + }) + }) + } + if m.supportedVersion != 0 { + b.AddUint16(extensionSupportedVersions) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.supportedVersion) + }) + } + if m.serverShare.group != 0 { + b.AddUint16(extensionKeyShare) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(uint16(m.serverShare.group)) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.serverShare.data) + }) + }) + } + if m.selectedIdentityPresent { + b.AddUint16(extensionPreSharedKey) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.selectedIdentity) + }) + } + + if len(m.cookie) > 0 { + b.AddUint16(extensionCookie) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.cookie) + }) + }) + } + if m.selectedGroup != 0 { + b.AddUint16(extensionKeyShare) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(uint16(m.selectedGroup)) + }) + } + if len(m.supportedPoints) > 0 { + b.AddUint16(extensionSupportedPoints) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.supportedPoints) + }) + }) + } + + extensionsPresent = len(b.BytesOrPanic()) > 2 + }) + + if !extensionsPresent { + *b = bWithoutExtensions + } + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *serverHelloMsg) unmarshal(data []byte) bool { + *m = serverHelloMsg{raw: data} + s := cryptobyte.String(data) + + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || + !readUint8LengthPrefixed(&s, &m.sessionId) || + !s.ReadUint16(&m.cipherSuite) || + !s.ReadUint8(&m.compressionMethod) { + return false + } + + if s.Empty() { + // ServerHello is optionally followed by extension data + return true + } + + var extensions cryptobyte.String + if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionStatusRequest: + m.ocspStapling = true + case extensionSessionTicket: + m.ticketSupported = true + case extensionRenegotiationInfo: + if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { + return false + } + m.secureRenegotiationSupported = true + case extensionALPN: + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { + return false + } + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || + proto.Empty() || !protoList.Empty() { + return false + } + m.alpnProtocol = string(proto) + case extensionSCT: + var sctList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { + return false + } + for !sctList.Empty() { + var sct []byte + if !readUint16LengthPrefixed(&sctList, &sct) || + len(sct) == 0 { + return false + } + m.scts = append(m.scts, sct) + } + case extensionSupportedVersions: + if !extData.ReadUint16(&m.supportedVersion) { + return false + } + case extensionCookie: + if !readUint16LengthPrefixed(&extData, &m.cookie) || + len(m.cookie) == 0 { + return false + } + case extensionKeyShare: + // This extension has different formats in SH and HRR, accept either + // and let the handshake logic decide. See RFC 8446, Section 4.2.8. + if len(extData) == 2 { + if !extData.ReadUint16((*uint16)(&m.selectedGroup)) { + return false + } + } else { + if !extData.ReadUint16((*uint16)(&m.serverShare.group)) || + !readUint16LengthPrefixed(&extData, &m.serverShare.data) { + return false + } + } + case extensionPreSharedKey: + m.selectedIdentityPresent = true + if !extData.ReadUint16(&m.selectedIdentity) { + return false + } + case extensionSupportedPoints: + // RFC 4492, Section 5.1.2 + if !readUint8LengthPrefixed(&extData, &m.supportedPoints) || + len(m.supportedPoints) == 0 { + return false + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type encryptedExtensionsMsg struct { + raw []byte + alpnProtocol string +} + +func (m *encryptedExtensionsMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeEncryptedExtensions) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if len(m.alpnProtocol) > 0 { + b.AddUint16(extensionALPN) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.alpnProtocol)) + }) + }) + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { + *m = encryptedExtensionsMsg{raw: data} + s := cryptobyte.String(data) + + var extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionALPN: + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { + return false + } + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || + proto.Empty() || !protoList.Empty() { + return false + } + m.alpnProtocol = string(proto) + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type endOfEarlyDataMsg struct{} + +func (m *endOfEarlyDataMsg) marshal() []byte { + x := make([]byte, 4) + x[0] = typeEndOfEarlyData + return x +} + +func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} + +type keyUpdateMsg struct { + raw []byte + updateRequested bool +} + +func (m *keyUpdateMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeKeyUpdate) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + if m.updateRequested { + b.AddUint8(1) + } else { + b.AddUint8(0) + } + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *keyUpdateMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + + var updateRequested uint8 + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8(&updateRequested) || !s.Empty() { + return false + } + switch updateRequested { + case 0: + m.updateRequested = false + case 1: + m.updateRequested = true + default: + return false + } + return true +} + +type newSessionTicketMsgTLS13 struct { + raw []byte + lifetime uint32 + ageAdd uint32 + nonce []byte + label []byte + maxEarlyData uint32 +} + +func (m *newSessionTicketMsgTLS13) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeNewSessionTicket) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint32(m.lifetime) + b.AddUint32(m.ageAdd) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.nonce) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.label) + }) + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.maxEarlyData > 0 { + b.AddUint16(extensionEarlyData) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint32(m.maxEarlyData) + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { + *m = newSessionTicketMsgTLS13{raw: data} + s := cryptobyte.String(data) + + var extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint32(&m.lifetime) || + !s.ReadUint32(&m.ageAdd) || + !readUint8LengthPrefixed(&s, &m.nonce) || + !readUint16LengthPrefixed(&s, &m.label) || + !s.ReadUint16LengthPrefixed(&extensions) || + !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionEarlyData: + if !extData.ReadUint32(&m.maxEarlyData) { + return false + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type certificateRequestMsgTLS13 struct { + raw []byte + ocspStapling bool + scts bool + supportedSignatureAlgorithms []SignatureScheme + supportedSignatureAlgorithmsCert []SignatureScheme + certificateAuthorities [][]byte +} + +func (m *certificateRequestMsgTLS13) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificateRequest) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + // certificate_request_context (SHALL be zero length unless used for + // post-handshake authentication) + b.AddUint8(0) + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.ocspStapling { + b.AddUint16(extensionStatusRequest) + b.AddUint16(0) // empty extension_data + } + if m.scts { + // RFC 8446, Section 4.4.2.1 makes no mention of + // signed_certificate_timestamp in CertificateRequest, but + // "Extensions in the Certificate message from the client MUST + // correspond to extensions in the CertificateRequest message + // from the server." and it appears in the table in Section 4.2. + b.AddUint16(extensionSCT) + b.AddUint16(0) // empty extension_data + } + if len(m.supportedSignatureAlgorithms) > 0 { + b.AddUint16(extensionSignatureAlgorithms) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithms { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.supportedSignatureAlgorithmsCert) > 0 { + b.AddUint16(extensionSignatureAlgorithmsCert) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.certificateAuthorities) > 0 { + b.AddUint16(extensionCertificateAuthorities) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, ca := range m.certificateAuthorities { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(ca) + }) + } + }) + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool { + *m = certificateRequestMsgTLS13{raw: data} + s := cryptobyte.String(data) + + var context, extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || + !s.ReadUint16LengthPrefixed(&extensions) || + !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionStatusRequest: + m.ocspStapling = true + case extensionSCT: + m.scts = true + case extensionSignatureAlgorithms: + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithms = append( + m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) + } + case extensionSignatureAlgorithmsCert: + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithmsCert = append( + m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) + } + case extensionCertificateAuthorities: + var auths cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&auths) || auths.Empty() { + return false + } + for !auths.Empty() { + var ca []byte + if !readUint16LengthPrefixed(&auths, &ca) || len(ca) == 0 { + return false + } + m.certificateAuthorities = append(m.certificateAuthorities, ca) + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type certificateMsg struct { + raw []byte + certificates [][]byte +} + +func (m *certificateMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + var i int + for _, slice := range m.certificates { + i += len(slice) + } + + length := 3 + 3*len(m.certificates) + i + x = make([]byte, 4+length) + x[0] = typeCertificate + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + certificateOctets := length - 3 + x[4] = uint8(certificateOctets >> 16) + x[5] = uint8(certificateOctets >> 8) + x[6] = uint8(certificateOctets) + + y := x[7:] + for _, slice := range m.certificates { + y[0] = uint8(len(slice) >> 16) + y[1] = uint8(len(slice) >> 8) + y[2] = uint8(len(slice)) + copy(y[3:], slice) + y = y[3+len(slice):] + } + + m.raw = x + return +} + +func (m *certificateMsg) unmarshal(data []byte) bool { + if len(data) < 7 { + return false + } + + m.raw = data + certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]) + if uint32(len(data)) != certsLen+7 { + return false + } + + numCerts := 0 + d := data[7:] + for certsLen > 0 { + if len(d) < 4 { + return false + } + certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) + if uint32(len(d)) < 3+certLen { + return false + } + d = d[3+certLen:] + certsLen -= 3 + certLen + numCerts++ + } + + m.certificates = make([][]byte, numCerts) + d = data[7:] + for i := 0; i < numCerts; i++ { + certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) + m.certificates[i] = d[3 : 3+certLen] + d = d[3+certLen:] + } + + return true +} + +type certificateMsgTLS13 struct { + raw []byte + certificate Certificate + ocspStapling bool + scts bool +} + +func (m *certificateMsgTLS13) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificate) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(0) // certificate_request_context + + certificate := m.certificate + if !m.ocspStapling { + certificate.OCSPStaple = nil + } + if !m.scts { + certificate.SignedCertificateTimestamps = nil + } + marshalCertificate(b, certificate) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + for i, cert := range certificate.Certificate { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(cert) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if i > 0 { + // This library only supports OCSP and SCT for leaf certificates. + return + } + if certificate.OCSPStaple != nil { + b.AddUint16(extensionStatusRequest) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(statusTypeOCSP) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(certificate.OCSPStaple) + }) + }) + } + if certificate.SignedCertificateTimestamps != nil { + b.AddUint16(extensionSCT) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sct := range certificate.SignedCertificateTimestamps { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(sct) + }) + } + }) + }) + } + }) + } + }) +} + +func (m *certificateMsgTLS13) unmarshal(data []byte) bool { + *m = certificateMsgTLS13{raw: data} + s := cryptobyte.String(data) + + var context cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || + !unmarshalCertificate(&s, &m.certificate) || + !s.Empty() { + return false + } + + m.scts = m.certificate.SignedCertificateTimestamps != nil + m.ocspStapling = m.certificate.OCSPStaple != nil + + return true +} + +func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool { + var certList cryptobyte.String + if !s.ReadUint24LengthPrefixed(&certList) { + return false + } + for !certList.Empty() { + var cert []byte + var extensions cryptobyte.String + if !readUint24LengthPrefixed(&certList, &cert) || + !certList.ReadUint16LengthPrefixed(&extensions) { + return false + } + certificate.Certificate = append(certificate.Certificate, cert) + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + if len(certificate.Certificate) > 1 { + // This library only supports OCSP and SCT for leaf certificates. + continue + } + + switch extension { + case extensionStatusRequest: + var statusType uint8 + if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP || + !readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) || + len(certificate.OCSPStaple) == 0 { + return false + } + case extensionSCT: + var sctList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { + return false + } + for !sctList.Empty() { + var sct []byte + if !readUint16LengthPrefixed(&sctList, &sct) || + len(sct) == 0 { + return false + } + certificate.SignedCertificateTimestamps = append( + certificate.SignedCertificateTimestamps, sct) + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + } + return true +} + +type serverKeyExchangeMsg struct { + raw []byte + key []byte +} + +func (m *serverKeyExchangeMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + length := len(m.key) + x := make([]byte, length+4) + x[0] = typeServerKeyExchange + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + copy(x[4:], m.key) + + m.raw = x + return x +} + +func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { + m.raw = data + if len(data) < 4 { + return false + } + m.key = data[4:] + return true +} + +type certificateStatusMsg struct { + raw []byte + response []byte +} + +func (m *certificateStatusMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificateStatus) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(statusTypeOCSP) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.response) + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *certificateStatusMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + + var statusType uint8 + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8(&statusType) || statusType != statusTypeOCSP || + !readUint24LengthPrefixed(&s, &m.response) || + len(m.response) == 0 || !s.Empty() { + return false + } + return true +} + +type serverHelloDoneMsg struct{} + +func (m *serverHelloDoneMsg) marshal() []byte { + x := make([]byte, 4) + x[0] = typeServerHelloDone + return x +} + +func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} + +type clientKeyExchangeMsg struct { + raw []byte + ciphertext []byte +} + +func (m *clientKeyExchangeMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + length := len(m.ciphertext) + x := make([]byte, length+4) + x[0] = typeClientKeyExchange + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + copy(x[4:], m.ciphertext) + + m.raw = x + return x +} + +func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { + m.raw = data + if len(data) < 4 { + return false + } + l := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + if l != len(data)-4 { + return false + } + m.ciphertext = data[4:] + return true +} + +type finishedMsg struct { + raw []byte + verifyData []byte +} + +func (m *finishedMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeFinished) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.verifyData) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *finishedMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + return s.Skip(1) && + readUint24LengthPrefixed(&s, &m.verifyData) && + s.Empty() +} + +type certificateRequestMsg struct { + raw []byte + // hasSignatureAlgorithm indicates whether this message includes a list of + // supported signature algorithms. This change was introduced with TLS 1.2. + hasSignatureAlgorithm bool + + certificateTypes []byte + supportedSignatureAlgorithms []SignatureScheme + certificateAuthorities [][]byte +} + +func (m *certificateRequestMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See RFC 4346, Section 7.4.4. + length := 1 + len(m.certificateTypes) + 2 + casLength := 0 + for _, ca := range m.certificateAuthorities { + casLength += 2 + len(ca) + } + length += casLength + + if m.hasSignatureAlgorithm { + length += 2 + 2*len(m.supportedSignatureAlgorithms) + } + + x = make([]byte, 4+length) + x[0] = typeCertificateRequest + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + x[4] = uint8(len(m.certificateTypes)) + + copy(x[5:], m.certificateTypes) + y := x[5+len(m.certificateTypes):] + + if m.hasSignatureAlgorithm { + n := len(m.supportedSignatureAlgorithms) * 2 + y[0] = uint8(n >> 8) + y[1] = uint8(n) + y = y[2:] + for _, sigAlgo := range m.supportedSignatureAlgorithms { + y[0] = uint8(sigAlgo >> 8) + y[1] = uint8(sigAlgo) + y = y[2:] + } + } + + y[0] = uint8(casLength >> 8) + y[1] = uint8(casLength) + y = y[2:] + for _, ca := range m.certificateAuthorities { + y[0] = uint8(len(ca) >> 8) + y[1] = uint8(len(ca)) + y = y[2:] + copy(y, ca) + y = y[len(ca):] + } + + m.raw = x + return +} + +func (m *certificateRequestMsg) unmarshal(data []byte) bool { + m.raw = data + + if len(data) < 5 { + return false + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return false + } + + numCertTypes := int(data[4]) + data = data[5:] + if numCertTypes == 0 || len(data) <= numCertTypes { + return false + } + + m.certificateTypes = make([]byte, numCertTypes) + if copy(m.certificateTypes, data) != numCertTypes { + return false + } + + data = data[numCertTypes:] + + if m.hasSignatureAlgorithm { + if len(data) < 2 { + return false + } + sigAndHashLen := uint16(data[0])<<8 | uint16(data[1]) + data = data[2:] + if sigAndHashLen&1 != 0 { + return false + } + if len(data) < int(sigAndHashLen) { + return false + } + numSigAlgos := sigAndHashLen / 2 + m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos) + for i := range m.supportedSignatureAlgorithms { + m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1]) + data = data[2:] + } + } + + if len(data) < 2 { + return false + } + casLength := uint16(data[0])<<8 | uint16(data[1]) + data = data[2:] + if len(data) < int(casLength) { + return false + } + cas := make([]byte, casLength) + copy(cas, data) + data = data[casLength:] + + m.certificateAuthorities = nil + for len(cas) > 0 { + if len(cas) < 2 { + return false + } + caLen := uint16(cas[0])<<8 | uint16(cas[1]) + cas = cas[2:] + + if len(cas) < int(caLen) { + return false + } + + m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen]) + cas = cas[caLen:] + } + + return len(data) == 0 +} + +type certificateVerifyMsg struct { + raw []byte + hasSignatureAlgorithm bool // format change introduced in TLS 1.2 + signatureAlgorithm SignatureScheme + signature []byte +} + +func (m *certificateVerifyMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificateVerify) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + if m.hasSignatureAlgorithm { + b.AddUint16(uint16(m.signatureAlgorithm)) + } + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.signature) + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *certificateVerifyMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + + if !s.Skip(4) { // message type and uint24 length field + return false + } + if m.hasSignatureAlgorithm { + if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) { + return false + } + } + return readUint16LengthPrefixed(&s, &m.signature) && s.Empty() +} + +type newSessionTicketMsg struct { + raw []byte + ticket []byte +} + +func (m *newSessionTicketMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See RFC 5077, Section 3.3. + ticketLen := len(m.ticket) + length := 2 + 4 + ticketLen + x = make([]byte, 4+length) + x[0] = typeNewSessionTicket + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + x[8] = uint8(ticketLen >> 8) + x[9] = uint8(ticketLen) + copy(x[10:], m.ticket) + + m.raw = x + + return +} + +func (m *newSessionTicketMsg) unmarshal(data []byte) bool { + m.raw = data + + if len(data) < 10 { + return false + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return false + } + + ticketLen := int(data[8])<<8 + int(data[9]) + if len(data)-10 != ticketLen { + return false + } + + m.ticket = data[10:] + + return true +} + +type helloRequestMsg struct { +} + +func (*helloRequestMsg) marshal() []byte { + return []byte{typeHelloRequest, 0, 0, 0} +} + +func (*helloRequestMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} diff --git a/pkg/tls/handshake_server.go b/pkg/tls/handshake_server.go new file mode 100644 index 000000000..ede1c52e0 --- /dev/null +++ b/pkg/tls/handshake_server.go @@ -0,0 +1,869 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/subtle" + "crypto/x509" + "errors" + "fmt" + "io" + "strconv" + "time" +) + +// serverHandshakeState contains details of a server handshake in progress. +// It's discarded once the handshake has completed. +type serverHandshakeState struct { + c *Conn + clientHello *clientHelloMsg + hello *serverHelloMsg + suite *cipherSuite + ecdheOk bool + ecSignOk bool + rsaDecryptOk bool + rsaSignOk bool + sessionState *sessionState + finishedHash finishedHash + masterSecret []byte + cert *Certificate + + keyAgreement keyAgreement + certReq *certificateRequestMsg +} + +// serverHandshake performs a TLS handshake as a server. +func (c *Conn) serverHandshake() error { + // If this is the first server handshake, we generate a random key to + // encrypt the tickets with. + //gnet不能进行阻塞二次读取,所以会分几条消息重复执行此方法,status也会分很多个状态 + if c.hs == nil { + //首次执行要初始化对象 + clientHello, err := c.readClientHello() + if err != nil { + return err + } + if c.vers == VersionTLS13 { + c.hs = &serverHandshakeStateTLS13{ + c: c, + clientHello: clientHello, + } + + } else { + c.hs = &serverHandshakeState{ + c: c, + clientHello: clientHello, + } + } + } + return c.hs.handshake() +} + +func (hs *serverHandshakeState) handshake() error { + c := hs.c + if c.handshakeStatus == 0 { + if err := hs.processClientHello(); err != nil { + return err + } + + // For an overview of TLS handshaking, see RFC 5246, Section 7.3. + //c.buffering = true + } + + if hs.checkForResumption() { + switch c.handshakeStatus { + case 0: + // The client has included a session ticket and so we do an abbreviated handshake. + c.didResume = true + if err := hs.doResumeHandshake(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.sendSessionTicket(); err != nil { + return err + } + if err := hs.sendFinished(c.serverFinished[:]); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + c.handshakeStatus = 1 + return nil + case 1: + c.clientFinishedIsFirst = false + if err := hs.readFinished(nil); err != nil { + return err + } + + default: + return errors.New("错误的status状态" + strconv.Itoa(int(c.handshakeStatus))) + } + } else { + // The client didn't include a session ticket, or it wasn't + // valid so we do a full handshake. + switch c.handshakeStatus { + case 0: + if err := hs.pickCipherSuite(); err != nil { + return err + } + if err := hs.doFullHandshakeStep1(); err != nil { + return err + } + c.handshakeStatus = 3 + return nil + case 3: + if err := hs.doFullHandshakeStep2(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + c.handshakeStatus = 4 + if c.rawInput.Len() < 5 { + return nil + } + fallthrough + case 4: + if err := hs.readFinished(c.clientFinished[:]); err != nil { + return err + } + c.clientFinishedIsFirst = true + //c.buffering = true + if err := hs.sendSessionTicket(); err != nil { + return err + } + if err := hs.sendFinished(nil); err != nil { + + return err + } + if _, err := c.flush(); err != nil { + return err + } + } + + } + + c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random) + c.handshakeStatus = 255 + return nil +} + +// readClientHello reads a ClientHello message and selects the protocol version. +func (c *Conn) readClientHello() (*clientHelloMsg, error) { + msg, err := c.readHandshake() + if err != nil { + return nil, err + } + clientHello, ok := msg.(*clientHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return nil, unexpectedMessageError(clientHello, msg) + } + + var configForClient *Config + originalConfig := c.config + if c.config.GetConfigForClient != nil { + chi := clientHelloInfo(c, clientHello) + if configForClient, err = c.config.GetConfigForClient(chi); err != nil { + c.sendAlert(alertInternalError) + return nil, err + } else if configForClient != nil { + c.config = configForClient + } + } + c.ticketKeys = originalConfig.ticketKeys(configForClient) + + clientVersions := clientHello.supportedVersions + if len(clientHello.supportedVersions) == 0 { + clientVersions = supportedVersionsFromMax(clientHello.vers) + } + c.vers, ok = c.config.mutualVersion(clientVersions) + if !ok { + c.sendAlert(alertProtocolVersion) + return nil, fmt.Errorf("tls: client offered only unsupported versions: %x", clientVersions) + } + c.haveVers = true + c.in.version = c.vers + c.out.version = c.vers + + return clientHello, nil +} + +func (hs *serverHandshakeState) processClientHello() error { + c := hs.c + + hs.hello = new(serverHelloMsg) + hs.hello.vers = c.vers + + foundCompression := false + // We only support null compression, so check that the client offered it. + for _, compression := range hs.clientHello.compressionMethods { + if compression == compressionNone { + foundCompression = true + break + } + } + + if !foundCompression { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: client does not support uncompressed connections") + } + + hs.hello.random = make([]byte, 32) + serverRandom := hs.hello.random + // Downgrade protection canaries. See RFC 8446, Section 4.1.3. + maxVers := c.config.maxSupportedVersion() + if maxVers >= VersionTLS12 && c.vers < maxVers || testingOnlyForceDowngradeCanary { + if c.vers == VersionTLS12 { + copy(serverRandom[24:], downgradeCanaryTLS12) + } else { + copy(serverRandom[24:], downgradeCanaryTLS11) + } + serverRandom = serverRandom[:24] + } + _, err := io.ReadFull(c.config.rand(), serverRandom) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + if len(hs.clientHello.secureRenegotiation) != 0 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: initial handshake had non-empty renegotiation extension") + } + + hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported + hs.hello.compressionMethod = compressionNone + if len(hs.clientHello.serverName) > 0 { + c.serverName = hs.clientHello.serverName + } + + if len(hs.clientHello.alpnProtocols) > 0 { + if selectedProto, fallback := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); !fallback { + hs.hello.alpnProtocol = selectedProto + c.clientProtocol = selectedProto + } + } + + hs.cert, err = c.config.getCertificate(clientHelloInfo(c, hs.clientHello)) + if err != nil { + if err == errNoCertificates { + c.sendAlert(alertUnrecognizedName) + } else { + c.sendAlert(alertInternalError) + } + return err + } + if hs.clientHello.scts { + hs.hello.scts = hs.cert.SignedCertificateTimestamps + } + + hs.ecdheOk = supportsECDHE(c.config, hs.clientHello.supportedCurves, hs.clientHello.supportedPoints) + + if hs.ecdheOk { + // Although omitting the ec_point_formats extension is permitted, some + // old OpenSSL version will refuse to handshake if not present. + // + // Per RFC 4492, section 5.1.2, implementations MUST support the + // uncompressed point format. See golang.org/issue/31943. + hs.hello.supportedPoints = []uint8{pointFormatUncompressed} + } + + if priv, ok := hs.cert.PrivateKey.(crypto.Signer); ok { + switch priv.Public().(type) { + case *ecdsa.PublicKey: + hs.ecSignOk = true + case ed25519.PublicKey: + hs.ecSignOk = true + case *rsa.PublicKey: + hs.rsaSignOk = true + default: + c.sendAlert(alertInternalError) + return fmt.Errorf("tls: unsupported signing key type (%T)", priv.Public()) + } + } + if priv, ok := hs.cert.PrivateKey.(crypto.Decrypter); ok { + switch priv.Public().(type) { + case *rsa.PublicKey: + hs.rsaDecryptOk = true + default: + c.sendAlert(alertInternalError) + return fmt.Errorf("tls: unsupported decryption key type (%T)", priv.Public()) + } + } + + return nil +} + +// supportsECDHE returns whether ECDHE key exchanges can be used with this +// pre-TLS 1.3 client. +func supportsECDHE(c *Config, supportedCurves []CurveID, supportedPoints []uint8) bool { + supportsCurve := false + for _, curve := range supportedCurves { + if c.supportsCurve(curve) { + supportsCurve = true + break + } + } + + supportsPointFormat := false + for _, pointFormat := range supportedPoints { + if pointFormat == pointFormatUncompressed { + supportsPointFormat = true + break + } + } + + return supportsCurve && supportsPointFormat +} + +func (hs *serverHandshakeState) pickCipherSuite() error { + c := hs.c + + var preferenceList, supportedList []uint16 + if c.config.PreferServerCipherSuites { + preferenceList = c.config.cipherSuites() + supportedList = hs.clientHello.cipherSuites + } else { + preferenceList = hs.clientHello.cipherSuites + supportedList = c.config.cipherSuites() + } + + hs.suite = selectCipherSuite(preferenceList, supportedList, hs.cipherSuiteOk) + if hs.suite == nil { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: no cipher suite supported by both client and server") + } + c.cipherSuite = hs.suite.id + + for _, id := range hs.clientHello.cipherSuites { + if id == TLS_FALLBACK_SCSV { + // The client is doing a fallback connection. See RFC 7507. + if hs.clientHello.vers < c.config.maxSupportedVersion() { + c.sendAlert(alertInappropriateFallback) + return errors.New("tls: client using inappropriate protocol fallback") + } + break + } + } + + return nil +} + +func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool { + if c.flags&suiteECDHE != 0 { + if !hs.ecdheOk { + return false + } + if c.flags&suiteECSign != 0 { + if !hs.ecSignOk { + return false + } + } else if !hs.rsaSignOk { + return false + } + } else if !hs.rsaDecryptOk { + return false + } + if hs.c.vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { + return false + } + return true +} + +// checkForResumption reports whether we should perform resumption on this connection. +func (hs *serverHandshakeState) checkForResumption() bool { + c := hs.c + + if c.config.SessionTicketsDisabled { + return false + } + + plaintext, usedOldKey := c.decryptTicket(hs.clientHello.sessionTicket) + if plaintext == nil { + return false + } + hs.sessionState = &sessionState{usedOldKey: usedOldKey} + ok := hs.sessionState.unmarshal(plaintext) + if !ok { + return false + } + + createdAt := time.Unix(int64(hs.sessionState.createdAt), 0) + if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { + return false + } + + // Never resume a session for a different TLS version. + if c.vers != hs.sessionState.vers { + return false + } + + cipherSuiteOk := false + // Check that the client is still offering the ciphersuite in the session. + for _, id := range hs.clientHello.cipherSuites { + if id == hs.sessionState.cipherSuite { + cipherSuiteOk = true + break + } + } + if !cipherSuiteOk { + return false + } + + // Check that we also support the ciphersuite from the session. + hs.suite = selectCipherSuite([]uint16{hs.sessionState.cipherSuite}, + c.config.cipherSuites(), hs.cipherSuiteOk) + if hs.suite == nil { + return false + } + + sessionHasClientCerts := len(hs.sessionState.certificates) != 0 + needClientCerts := requiresClientCert(c.config.ClientAuth) + if needClientCerts && !sessionHasClientCerts { + return false + } + if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { + return false + } + + return true +} + +func (hs *serverHandshakeState) doResumeHandshake() error { + c := hs.c + + hs.hello.cipherSuite = hs.suite.id + c.cipherSuite = hs.suite.id + // We echo the client's session ID in the ServerHello to let it know + // that we're doing a resumption. + hs.hello.sessionId = hs.clientHello.sessionId + hs.hello.ticketSupported = hs.sessionState.usedOldKey + hs.finishedHash = newFinishedHash(c.vers, hs.suite) + hs.finishedHash.discardHandshakeBuffer() + hs.finishedHash.Write(hs.clientHello.marshal()) + hs.finishedHash.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + if err := c.processCertsFromClient(Certificate{ + Certificate: hs.sessionState.certificates, + }); err != nil { + return err + } + + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + hs.masterSecret = hs.sessionState.masterSecret + + return nil +} + +func (hs *serverHandshakeState) doFullHandshakeStep1() error { + c := hs.c + + if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 { + hs.hello.ocspStapling = true + } + + hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled + hs.hello.cipherSuite = hs.suite.id + + hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite) + if c.config.ClientAuth == NoClientCert { + // No need to keep a full record of the handshake if client + // certificates won't be used. + hs.finishedHash.discardHandshakeBuffer() + } + hs.finishedHash.Write(hs.clientHello.marshal()) + hs.finishedHash.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + certMsg := new(certificateMsg) + certMsg.certificates = hs.cert.Certificate + hs.finishedHash.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + if hs.hello.ocspStapling { + certStatus := new(certificateStatusMsg) + certStatus.response = hs.cert.OCSPStaple + hs.finishedHash.Write(certStatus.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil { + return err + } + } + + hs.keyAgreement = hs.suite.ka(c.vers) + skx, err := hs.keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello) + if err != nil { + c.sendAlert(alertHandshakeFailure) + return err + } + if skx != nil { + hs.finishedHash.Write(skx.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil { + return err + } + } + + if c.config.ClientAuth >= RequestClientCert { + // Request a client certificate + hs.certReq = new(certificateRequestMsg) + hs.certReq.certificateTypes = []byte{ + byte(certTypeRSASign), + byte(certTypeECDSASign), + } + if c.vers >= VersionTLS12 { + hs.certReq.hasSignatureAlgorithm = true + hs.certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms + } + + // An empty list of certificateAuthorities signals to + // the client that it may send any certificate in response + // to our request. When we know the CAs we trust, then + // we can send them down, so that the client can choose + // an appropriate certificate to give to us. + if c.config.ClientCAs != nil { + hs.certReq.certificateAuthorities = c.config.ClientCAs.Subjects() + } + hs.finishedHash.Write(hs.certReq.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.certReq.marshal()); err != nil { + return err + } + } + + helloDone := new(serverHelloDoneMsg) + hs.finishedHash.Write(helloDone.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil { + return err + } + + _, err = c.flush() + return err + +} +func (hs *serverHandshakeState) doFullHandshakeStep2() error { + c := hs.c + + var pub crypto.PublicKey // public key for client auth, if any + msg, err := c.readHandshake() + if err != nil { + return err + } + + // If we requested a client certificate, then the client must send a + // certificate message, even if it's empty. + if c.config.ClientAuth >= RequestClientCert { + certMsg, ok := msg.(*certificateMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + hs.finishedHash.Write(certMsg.marshal()) + + if err := c.processCertsFromClient(Certificate{ + Certificate: certMsg.certificates, + }); err != nil { + return err + } + if len(certMsg.certificates) != 0 { + pub = c.peerCertificates[0].PublicKey + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + // Get client key exchange + ckx, ok := msg.(*clientKeyExchangeMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(ckx, msg) + } + hs.finishedHash.Write(ckx.marshal()) + + preMasterSecret, err := hs.keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers) + if err != nil { + c.sendAlert(alertHandshakeFailure) + return err + } + hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random) + if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.clientHello.random, hs.masterSecret); err != nil { + c.sendAlert(alertInternalError) + return err + } + + // If we received a client cert in response to our certificate request message, + // the client will send us a certificateVerifyMsg immediately after the + // clientKeyExchangeMsg. This message is a digest of all preceding + // handshake-layer messages that is signed using the private key corresponding + // to the client's certificate. This allows us to verify that the client is in + // possession of the private key of the certificate. + if len(c.peerCertificates) > 0 { + msg, err = c.readHandshake() + if err != nil { + return err + } + certVerify, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certVerify, msg) + } + + var sigType uint8 + var sigHash crypto.Hash + if c.vers >= VersionTLS12 { + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, hs.certReq.supportedSignatureAlgorithms) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client certificate used with invalid signature algorithm") + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(pub) + if err != nil { + c.sendAlert(alertIllegalParameter) + return err + } + } + + signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret) + if err := verifyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid signature by the client certificate: " + err.Error()) + } + + hs.finishedHash.Write(certVerify.marshal()) + } + + hs.finishedHash.discardHandshakeBuffer() + + return nil +} + +func (hs *serverHandshakeState) establishKeys() error { + c := hs.c + + clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := + keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) + + var clientCipher, serverCipher interface{} + var clientHash, serverHash macFunction + + if hs.suite.aead == nil { + clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */) + clientHash = hs.suite.mac(c.vers, clientMAC) + serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */) + serverHash = hs.suite.mac(c.vers, serverMAC) + } else { + clientCipher = hs.suite.aead(clientKey, clientIV) + serverCipher = hs.suite.aead(serverKey, serverIV) + } + + c.in.prepareCipherSpec(c.vers, clientCipher, clientHash) + c.out.prepareCipherSpec(c.vers, serverCipher, serverHash) + + return nil +} + +func (hs *serverHandshakeState) readFinished(out []byte) error { + c := hs.c + + if err := c.readChangeCipherSpec(); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + clientFinished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(clientFinished, msg) + } + + verify := hs.finishedHash.clientSum(hs.masterSecret) + if len(verify) != len(clientFinished.verifyData) || + subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: client's Finished message is incorrect") + } + + hs.finishedHash.Write(clientFinished.marshal()) + copy(out, verify) + return nil +} + +func (hs *serverHandshakeState) sendSessionTicket() error { + // ticketSupported is set in a resumption handshake if the + // ticket from the client was encrypted with an old session + // ticket key and thus a refreshed ticket should be sent. + if !hs.hello.ticketSupported { + return nil + } + + c := hs.c + m := new(newSessionTicketMsg) + + createdAt := uint64(c.config.time().Unix()) + if hs.sessionState != nil { + // If this is re-wrapping an old key, then keep + // the original time it was created. + createdAt = hs.sessionState.createdAt + } + + var certsFromClient [][]byte + for _, cert := range c.peerCertificates { + certsFromClient = append(certsFromClient, cert.Raw) + } + state := sessionState{ + vers: c.vers, + cipherSuite: hs.suite.id, + createdAt: createdAt, + masterSecret: hs.masterSecret, + certificates: certsFromClient, + } + var err error + m.ticket, err = c.encryptTicket(state.marshal()) + if err != nil { + return err + } + + hs.finishedHash.Write(m.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeState) sendFinished(out []byte) error { + c := hs.c + + if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + return err + } + + finished := new(finishedMsg) + finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) + hs.finishedHash.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + + copy(out, finished.verifyData) + + return nil +} + +// processCertsFromClient takes a chain of client certificates either from a +// Certificates message or from a sessionState and verifies them. It returns +// the public key of the leaf certificate. +func (c *Conn) processCertsFromClient(certificate Certificate) error { + certificates := certificate.Certificate + certs := make([]*x509.Certificate, len(certificates)) + var err error + for i, asn1Data := range certificates { + if certs[i], err = x509.ParseCertificate(asn1Data); err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("tls: failed to parse client certificate: " + err.Error()) + } + } + + if len(certs) == 0 && requiresClientCert(c.config.ClientAuth) { + c.sendAlert(alertBadCertificate) + return errors.New("tls: client didn't provide a certificate") + } + + if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 { + opts := x509.VerifyOptions{ + Roots: c.config.ClientCAs, + CurrentTime: c.config.time(), + Intermediates: x509.NewCertPool(), + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + + chains, err := certs[0].Verify(opts) + if err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("tls: failed to verify client certificate: " + err.Error()) + } + + c.verifiedChains = chains + } + + c.peerCertificates = certs + c.ocspResponse = certificate.OCSPStaple + c.scts = certificate.SignedCertificateTimestamps + + if len(certs) > 0 { + switch certs[0].PublicKey.(type) { + case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey: + default: + c.sendAlert(alertUnsupportedCertificate) + return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey) + } + } + + if c.config.VerifyPeerCertificate != nil { + if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + return nil +} + +func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { + supportedVersions := clientHello.supportedVersions + if len(clientHello.supportedVersions) == 0 { + supportedVersions = supportedVersionsFromMax(clientHello.vers) + } + + return &ClientHelloInfo{ + CipherSuites: clientHello.cipherSuites, + ServerName: clientHello.serverName, + SupportedCurves: clientHello.supportedCurves, + SupportedPoints: clientHello.supportedPoints, + SignatureSchemes: clientHello.supportedSignatureAlgorithms, + SupportedProtos: clientHello.alpnProtocols, + SupportedVersions: supportedVersions, + Conn: c.conn, + config: c.config, + } +} diff --git a/pkg/tls/handshake_server_tls13.go b/pkg/tls/handshake_server_tls13.go new file mode 100644 index 000000000..467612649 --- /dev/null +++ b/pkg/tls/handshake_server_tls13.go @@ -0,0 +1,869 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "crypto" + "crypto/hmac" + "crypto/rsa" + "errors" + "hash" + "io" + "time" +) + +// maxClientPSKIdentities is the number of client PSK identities the server will +// attempt to validate. It will ignore the rest not to let cheap ClientHello +// messages cause too much work in session ticket decryption attempts. +const maxClientPSKIdentities = 5 + +type serverHandshakeStateTLS13 struct { + c *Conn + clientHello *clientHelloMsg + hello *serverHelloMsg + sentDummyCCS bool + usingPSK bool + suite *cipherSuiteTLS13 + cert *Certificate + sigAlg SignatureScheme + earlySecret []byte + sharedKey []byte + handshakeSecret []byte + masterSecret []byte + trafficSecret []byte // client_application_traffic_secret_0 + transcript hash.Hash + clientFinished []byte +} + +func (hs *serverHandshakeStateTLS13) handshake() error { + c := hs.c + switch c.handshakeStatus { + case 0: + // For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2. + if err := hs.processClientHello(); err != nil { + + return err + } + if err := hs.checkForResumption(); err != nil { + + return err + } + if err := hs.pickCertificate(); err != nil { + + return err + } + //c.buffering = true + if err := hs.sendServerParameters(); err != nil { + + return err + } + if err := hs.sendServerCertificate(); err != nil { + + return err + } + if err := hs.sendServerFinished(); err != nil { + + return err + } + // Note that at this point we could start sending application data without + // waiting for the client's second flight, but the application might not + // expect the lack of replay protection of the ClientHello parameters. + if _, err := c.flush(); err != nil { + + return err + } + c.handshakeStatus = 1 + case 1: + if err := hs.readClientCertificate(); err != nil { + return err + } + if err := hs.readClientFinished(); err != nil { + + return err + } + c.handshakeStatus = 255 + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) processClientHello() error { + c := hs.c + + hs.hello = new(serverHelloMsg) + + // TLS 1.3 froze the ServerHello.legacy_version field, and uses + // supported_versions instead. See RFC 8446, sections 4.1.3 and 4.2.1. + hs.hello.vers = VersionTLS12 + hs.hello.supportedVersion = c.vers + + if len(hs.clientHello.supportedVersions) == 0 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client used the legacy version field to negotiate TLS 1.3") + } + + // Abort if the client is doing a fallback and landing lower than what we + // support. See RFC 7507, which however does not specify the interaction + // with supported_versions. The only difference is that with + // supported_versions a client has a chance to attempt a [TLS 1.2, TLS 1.4] + // handshake in case TLS 1.3 is broken but 1.2 is not. Alas, in that case, + // it will have to drop the TLS_FALLBACK_SCSV protection if it falls back to + // TLS 1.2, because a TLS 1.3 server would abort here. The situation before + // supported_versions was not better because there was just no way to do a + // TLS 1.4 handshake without risking the server selecting TLS 1.3. + for _, id := range hs.clientHello.cipherSuites { + if id == TLS_FALLBACK_SCSV { + // Use c.vers instead of max(supported_versions) because an attacker + // could defeat this by adding an arbitrary high version otherwise. + if c.vers < c.config.maxSupportedVersion() { + c.sendAlert(alertInappropriateFallback) + return errors.New("tls: client using inappropriate protocol fallback") + } + break + } + } + + if len(hs.clientHello.compressionMethods) != 1 || + hs.clientHello.compressionMethods[0] != compressionNone { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: TLS 1.3 client supports illegal compression methods") + } + + hs.hello.random = make([]byte, 32) + if _, err := io.ReadFull(c.config.rand(), hs.hello.random); err != nil { + c.sendAlert(alertInternalError) + return err + } + + if len(hs.clientHello.secureRenegotiation) != 0 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: initial handshake had non-empty renegotiation extension") + } + + if hs.clientHello.earlyData { + // See RFC 8446, Section 4.2.10 for the complicated behavior required + // here. The scenario is that a different server at our address offered + // to accept early data in the past, which we can't handle. For now, all + // 0-RTT enabled session tickets need to expire before a Go server can + // replace a server or join a pool. That's the same requirement that + // applies to mixing or replacing with any TLS 1.2 server. + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: client sent unexpected early data") + } + + hs.hello.sessionId = hs.clientHello.sessionId + hs.hello.compressionMethod = compressionNone + + var preferenceList, supportedList []uint16 + if c.config.PreferServerCipherSuites { + preferenceList = defaultCipherSuitesTLS13() + supportedList = hs.clientHello.cipherSuites + } else { + preferenceList = hs.clientHello.cipherSuites + supportedList = defaultCipherSuitesTLS13() + } + for _, suiteID := range preferenceList { + hs.suite = mutualCipherSuiteTLS13(supportedList, suiteID) + if hs.suite != nil { + break + } + } + if hs.suite == nil { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: no cipher suite supported by both client and server") + } + c.cipherSuite = hs.suite.id + hs.hello.cipherSuite = hs.suite.id + hs.transcript = hs.suite.hash.New() + + // Pick the ECDHE group in server preference order, but give priority to + // groups with a key share, to avoid a HelloRetryRequest round-trip. + var selectedGroup CurveID + var clientKeyShare *keyShare +GroupSelection: + for _, preferredGroup := range c.config.curvePreferences() { + for _, ks := range hs.clientHello.keyShares { + if ks.group == preferredGroup { + selectedGroup = ks.group + clientKeyShare = &ks + break GroupSelection + } + } + if selectedGroup != 0 { + continue + } + for _, group := range hs.clientHello.supportedCurves { + if group == preferredGroup { + selectedGroup = group + break + } + } + } + if selectedGroup == 0 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: no ECDHE curve supported by both client and server") + } + if clientKeyShare == nil { + if err := hs.doHelloRetryRequest(selectedGroup); err != nil { + return err + } + clientKeyShare = &hs.clientHello.keyShares[0] + } + + if _, ok := curveForCurveID(selectedGroup); selectedGroup != X25519 && !ok { + c.sendAlert(alertInternalError) + return errors.New("tls: CurvePreferences includes unsupported curve") + } + params, err := generateECDHEParameters(c.config.rand(), selectedGroup) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + hs.hello.serverShare = keyShare{group: selectedGroup, data: params.PublicKey()} + hs.sharedKey = params.SharedKey(clientKeyShare.data) + if hs.sharedKey == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid client key share") + } + + c.serverName = hs.clientHello.serverName + return nil +} + +func (hs *serverHandshakeStateTLS13) checkForResumption() error { + c := hs.c + + if c.config.SessionTicketsDisabled { + return nil + } + + modeOK := false + for _, mode := range hs.clientHello.pskModes { + if mode == pskModeDHE { + modeOK = true + break + } + } + if !modeOK { + return nil + } + + if len(hs.clientHello.pskIdentities) != len(hs.clientHello.pskBinders) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid or missing PSK binders") + } + if len(hs.clientHello.pskIdentities) == 0 { + return nil + } + + for i, identity := range hs.clientHello.pskIdentities { + if i >= maxClientPSKIdentities { + break + } + + plaintext, _ := c.decryptTicket(identity.label) + if plaintext == nil { + continue + } + sessionState := new(sessionStateTLS13) + if ok := sessionState.unmarshal(plaintext); !ok { + continue + } + + createdAt := time.Unix(int64(sessionState.createdAt), 0) + if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { + continue + } + + // We don't check the obfuscated ticket age because it's affected by + // clock skew and it's only a freshness signal useful for shrinking the + // window for replay attacks, which don't affect us as we don't do 0-RTT. + + pskSuite := cipherSuiteTLS13ByID(sessionState.cipherSuite) + if pskSuite == nil || pskSuite.hash != hs.suite.hash { + continue + } + + // PSK connections don't re-establish client certificates, but carry + // them over in the session ticket. Ensure the presence of client certs + // in the ticket is consistent with the configured requirements. + sessionHasClientCerts := len(sessionState.certificate.Certificate) != 0 + needClientCerts := requiresClientCert(c.config.ClientAuth) + if needClientCerts && !sessionHasClientCerts { + continue + } + if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { + continue + } + + psk := hs.suite.expandLabel(sessionState.resumptionSecret, "resumption", + nil, hs.suite.hash.Size()) + hs.earlySecret = hs.suite.extract(psk, nil) + binderKey := hs.suite.deriveSecret(hs.earlySecret, resumptionBinderLabel, nil) + // Clone the transcript in case a HelloRetryRequest was recorded. + transcript := cloneHash(hs.transcript, hs.suite.hash) + if transcript == nil { + c.sendAlert(alertInternalError) + return errors.New("tls: internal error: failed to clone hash") + } + transcript.Write(hs.clientHello.marshalWithoutBinders()) + pskBinder := hs.suite.finishedHash(binderKey, transcript) + if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid PSK binder") + } + + c.didResume = true + if err := c.processCertsFromClient(sessionState.certificate); err != nil { + return err + } + + hs.hello.selectedIdentityPresent = true + hs.hello.selectedIdentity = uint16(i) + hs.usingPSK = true + return nil + } + + return nil +} + +// cloneHash uses the encoding.BinaryMarshaler and encoding.BinaryUnmarshaler +// interfaces implemented by standard library hashes to clone the state of in +// to a new instance of h. It returns nil if the operation fails. +func cloneHash(in hash.Hash, h crypto.Hash) hash.Hash { + // Recreate the interface to avoid importing encoding. + type binaryMarshaler interface { + MarshalBinary() (data []byte, err error) + UnmarshalBinary(data []byte) error + } + marshaler, ok := in.(binaryMarshaler) + if !ok { + return nil + } + state, err := marshaler.MarshalBinary() + if err != nil { + return nil + } + out := h.New() + unmarshaler, ok := out.(binaryMarshaler) + if !ok { + return nil + } + if err := unmarshaler.UnmarshalBinary(state); err != nil { + return nil + } + return out +} + +func (hs *serverHandshakeStateTLS13) pickCertificate() error { + c := hs.c + + // Only one of PSK and certificates are used at a time. + if hs.usingPSK { + return nil + } + + // signature_algorithms is required in TLS 1.3. See RFC 8446, Section 4.2.3. + if len(hs.clientHello.supportedSignatureAlgorithms) == 0 { + return c.sendAlert(alertMissingExtension) + } + + certificate, err := c.config.getCertificate(clientHelloInfo(c, hs.clientHello)) + if err != nil { + if err == errNoCertificates { + c.sendAlert(alertUnrecognizedName) + } else { + c.sendAlert(alertInternalError) + } + return err + } + hs.sigAlg, err = selectSignatureScheme(c.vers, certificate, hs.clientHello.supportedSignatureAlgorithms) + if err != nil { + // getCertificate returned a certificate that is unsupported or + // incompatible with the client's signature algorithms. + c.sendAlert(alertHandshakeFailure) + return err + } + hs.cert = certificate + + return nil +} + +// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility +// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. +func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error { + if hs.sentDummyCCS { + return nil + } + hs.sentDummyCCS = true + + _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) + return err +} + +func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error { + c := hs.c + + // The first ClientHello gets double-hashed into the transcript upon a + // HelloRetryRequest. See RFC 8446, Section 4.4.1. + hs.transcript.Write(hs.clientHello.marshal()) + chHash := hs.transcript.Sum(nil) + hs.transcript.Reset() + hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) + hs.transcript.Write(chHash) + + helloRetryRequest := &serverHelloMsg{ + vers: hs.hello.vers, + random: helloRetryRequestRandom, + sessionId: hs.hello.sessionId, + cipherSuite: hs.hello.cipherSuite, + compressionMethod: hs.hello.compressionMethod, + supportedVersion: hs.hello.supportedVersion, + selectedGroup: selectedGroup, + } + + hs.transcript.Write(helloRetryRequest.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, helloRetryRequest.marshal()); err != nil { + return err + } + + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + clientHello, ok := msg.(*clientHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(clientHello, msg) + } + + if len(clientHello.keyShares) != 1 || clientHello.keyShares[0].group != selectedGroup { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client sent invalid key share in second ClientHello") + } + + if clientHello.earlyData { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client indicated early data in second ClientHello") + } + + if illegalClientHelloChange(clientHello, hs.clientHello) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client illegally modified second ClientHello") + } + + hs.clientHello = clientHello + return nil +} + +// illegalClientHelloChange reports whether the two ClientHello messages are +// different, with the exception of the changes allowed before and after a +// HelloRetryRequest. See RFC 8446, Section 4.1.2. +func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool { + if len(ch.supportedVersions) != len(ch1.supportedVersions) || + len(ch.cipherSuites) != len(ch1.cipherSuites) || + len(ch.supportedCurves) != len(ch1.supportedCurves) || + len(ch.supportedSignatureAlgorithms) != len(ch1.supportedSignatureAlgorithms) || + len(ch.supportedSignatureAlgorithmsCert) != len(ch1.supportedSignatureAlgorithmsCert) || + len(ch.alpnProtocols) != len(ch1.alpnProtocols) { + return true + } + for i := range ch.supportedVersions { + if ch.supportedVersions[i] != ch1.supportedVersions[i] { + return true + } + } + for i := range ch.cipherSuites { + if ch.cipherSuites[i] != ch1.cipherSuites[i] { + return true + } + } + for i := range ch.supportedCurves { + if ch.supportedCurves[i] != ch1.supportedCurves[i] { + return true + } + } + for i := range ch.supportedSignatureAlgorithms { + if ch.supportedSignatureAlgorithms[i] != ch1.supportedSignatureAlgorithms[i] { + return true + } + } + for i := range ch.supportedSignatureAlgorithmsCert { + if ch.supportedSignatureAlgorithmsCert[i] != ch1.supportedSignatureAlgorithmsCert[i] { + return true + } + } + for i := range ch.alpnProtocols { + if ch.alpnProtocols[i] != ch1.alpnProtocols[i] { + return true + } + } + return ch.vers != ch1.vers || + !bytes.Equal(ch.random, ch1.random) || + !bytes.Equal(ch.sessionId, ch1.sessionId) || + !bytes.Equal(ch.compressionMethods, ch1.compressionMethods) || + ch.serverName != ch1.serverName || + ch.ocspStapling != ch1.ocspStapling || + !bytes.Equal(ch.supportedPoints, ch1.supportedPoints) || + ch.ticketSupported != ch1.ticketSupported || + !bytes.Equal(ch.sessionTicket, ch1.sessionTicket) || + ch.secureRenegotiationSupported != ch1.secureRenegotiationSupported || + !bytes.Equal(ch.secureRenegotiation, ch1.secureRenegotiation) || + ch.scts != ch1.scts || + !bytes.Equal(ch.cookie, ch1.cookie) || + !bytes.Equal(ch.pskModes, ch1.pskModes) +} + +func (hs *serverHandshakeStateTLS13) sendServerParameters() error { + c := hs.c + + hs.transcript.Write(hs.clientHello.marshal()) + hs.transcript.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + + earlySecret := hs.earlySecret + if earlySecret == nil { + earlySecret = hs.suite.extract(nil, nil) + } + hs.handshakeSecret = hs.suite.extract(hs.sharedKey, + hs.suite.deriveSecret(earlySecret, "derived", nil)) + + clientSecret := hs.suite.deriveSecret(hs.handshakeSecret, + clientHandshakeTrafficLabel, hs.transcript) + c.in.setTrafficSecret(hs.suite, clientSecret) + serverSecret := hs.suite.deriveSecret(hs.handshakeSecret, + serverHandshakeTrafficLabel, hs.transcript) + c.out.setTrafficSecret(hs.suite, serverSecret) + + err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.clientHello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + encryptedExtensions := new(encryptedExtensionsMsg) + + if len(hs.clientHello.alpnProtocols) > 0 { + if selectedProto, fallback := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); !fallback { + encryptedExtensions.alpnProtocol = selectedProto + c.clientProtocol = selectedProto + } + } + + hs.transcript.Write(encryptedExtensions.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, encryptedExtensions.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) requestClientCert() bool { + return hs.c.config.ClientAuth >= RequestClientCert && !hs.usingPSK +} + +func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { + c := hs.c + + // Only one of PSK and certificates are used at a time. + if hs.usingPSK { + return nil + } + + if hs.requestClientCert() { + // Request a client certificate + certReq := new(certificateRequestMsgTLS13) + certReq.ocspStapling = true + certReq.scts = true + certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms + if c.config.ClientCAs != nil { + certReq.certificateAuthorities = c.config.ClientCAs.Subjects() + } + + hs.transcript.Write(certReq.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { + return err + } + } + + certMsg := new(certificateMsgTLS13) + + certMsg.certificate = *hs.cert + certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0 + certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 + + hs.transcript.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + certVerifyMsg := new(certificateVerifyMsg) + certVerifyMsg.hasSignatureAlgorithm = true + certVerifyMsg.signatureAlgorithm = hs.sigAlg + + sigType, sigHash, err := typeAndHashFromSignatureScheme(hs.sigAlg) + if err != nil { + return c.sendAlert(alertInternalError) + } + + signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + sig, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) + if err != nil { + public := hs.cert.PrivateKey.(crypto.Signer).Public() + if rsaKey, ok := public.(*rsa.PublicKey); ok && sigType == signatureRSAPSS && + rsaKey.N.BitLen()/8 < sigHash.Size()*2+2 { // key too small for RSA-PSS + c.sendAlert(alertHandshakeFailure) + } else { + c.sendAlert(alertInternalError) + } + return errors.New("tls: failed to sign handshake: " + err.Error()) + } + certVerifyMsg.signature = sig + + hs.transcript.Write(certVerifyMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) sendServerFinished() error { + c := hs.c + + finished := &finishedMsg{ + verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), + } + + hs.transcript.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + + // Derive secrets that take context through the server Finished. + + hs.masterSecret = hs.suite.extract(nil, + hs.suite.deriveSecret(hs.handshakeSecret, "derived", nil)) + + hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, + clientApplicationTrafficLabel, hs.transcript) + serverSecret := hs.suite.deriveSecret(hs.masterSecret, + serverApplicationTrafficLabel, hs.transcript) + c.out.setTrafficSecret(hs.suite, serverSecret) + + err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.clientHello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript) + + // If we did not request client certificates, at this point we can + // precompute the client finished and roll the transcript forward to send + // session tickets in our first flight. + if !hs.requestClientCert() { + if err := hs.sendSessionTickets(); err != nil { + return err + } + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) shouldSendSessionTickets() bool { + if hs.c.config.SessionTicketsDisabled { + return false + } + + // Don't send tickets the client wouldn't use. See RFC 8446, Section 4.2.9. + for _, pskMode := range hs.clientHello.pskModes { + if pskMode == pskModeDHE { + return true + } + } + return false +} + +func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { + c := hs.c + + hs.clientFinished = hs.suite.finishedHash(c.in.trafficSecret, hs.transcript) + finishedMsg := &finishedMsg{ + verifyData: hs.clientFinished, + } + hs.transcript.Write(finishedMsg.marshal()) + + if !hs.shouldSendSessionTickets() { + return nil + } + + resumptionSecret := hs.suite.deriveSecret(hs.masterSecret, + resumptionLabel, hs.transcript) + + m := new(newSessionTicketMsgTLS13) + + var certsFromClient [][]byte + for _, cert := range c.peerCertificates { + certsFromClient = append(certsFromClient, cert.Raw) + } + state := sessionStateTLS13{ + cipherSuite: hs.suite.id, + createdAt: uint64(c.config.time().Unix()), + resumptionSecret: resumptionSecret, + certificate: Certificate{ + Certificate: certsFromClient, + OCSPStaple: c.ocspResponse, + SignedCertificateTimestamps: c.scts, + }, + } + var err error + m.label, err = c.encryptTicket(state.marshal()) + if err != nil { + return err + } + m.lifetime = uint32(maxSessionTicketLifetime / time.Second) + + if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) readClientCertificate() error { + c := hs.c + + if !hs.requestClientCert() { + // Make sure the connection is still being verified whether or not + // the server requested a client certificate. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + return nil + } + + // If we requested a client certificate, then the client must send a + // certificate message. If it's empty, no CertificateVerify is sent. + + msg, err := c.readHandshake() + if err != nil { + return err + } + + certMsg, ok := msg.(*certificateMsgTLS13) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + hs.transcript.Write(certMsg.marshal()) + + if err := c.processCertsFromClient(certMsg.certificate); err != nil { + return err + } + + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + if len(certMsg.certificate.Certificate) != 0 { + msg, err = c.readHandshake() + if err != nil { + return err + } + + certVerify, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certVerify, msg) + } + + // See RFC 8446, Section 4.4.3. + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client certificate used with invalid signature algorithm") + } + sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client certificate used with invalid signature algorithm") + } + signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) + if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, + sigHash, signed, certVerify.signature); err != nil { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid signature by the client certificate: " + err.Error()) + } + + hs.transcript.Write(certVerify.marshal()) + } + + // If we waited until the client certificates to send session tickets, we + // are ready to do it now. + if err := hs.sendSessionTickets(); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) readClientFinished() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + + finished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(finished, msg) + } + + if !hmac.Equal(hs.clientFinished, finished.verifyData) { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid client finished hash") + } + + c.in.setTrafficSecret(hs.suite, hs.trafficSecret) + + return nil +} diff --git a/pkg/tls/key_agreement.go b/pkg/tls/key_agreement.go new file mode 100644 index 000000000..7e6534bd4 --- /dev/null +++ b/pkg/tls/key_agreement.go @@ -0,0 +1,334 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "crypto" + "crypto/md5" + "crypto/rsa" + "crypto/sha1" + "crypto/x509" + "errors" + "fmt" + "io" +) + +var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message") +var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message") + +// rsaKeyAgreement implements the standard TLS key agreement where the client +// encrypts the pre-master secret to the server's public key. +type rsaKeyAgreement struct{} + +func (ka rsaKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { + return nil, nil +} + +func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { + if len(ckx.ciphertext) < 2 { + return nil, errClientKeyExchange + } + ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1]) + if ciphertextLen != len(ckx.ciphertext)-2 { + return nil, errClientKeyExchange + } + ciphertext := ckx.ciphertext[2:] + + priv, ok := cert.PrivateKey.(crypto.Decrypter) + if !ok { + return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter") + } + // Perform constant time RSA PKCS #1 v1.5 decryption + preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48}) + if err != nil { + return nil, err + } + // We don't check the version number in the premaster secret. For one, + // by checking it, we would leak information about the validity of the + // encrypted pre-master secret. Secondly, it provides only a small + // benefit against a downgrade attack and some implementations send the + // wrong version anyway. See the discussion at the end of section + // 7.4.7.1 of RFC 4346. + return preMasterSecret, nil +} + +func (ka rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { + return errors.New("tls: unexpected ServerKeyExchange") +} + +func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { + preMasterSecret := make([]byte, 48) + preMasterSecret[0] = byte(clientHello.vers >> 8) + preMasterSecret[1] = byte(clientHello.vers) + _, err := io.ReadFull(config.rand(), preMasterSecret[2:]) + if err != nil { + return nil, nil, err + } + + encrypted, err := rsa.EncryptPKCS1v15(config.rand(), cert.PublicKey.(*rsa.PublicKey), preMasterSecret) + if err != nil { + return nil, nil, err + } + ckx := new(clientKeyExchangeMsg) + ckx.ciphertext = make([]byte, len(encrypted)+2) + ckx.ciphertext[0] = byte(len(encrypted) >> 8) + ckx.ciphertext[1] = byte(len(encrypted)) + copy(ckx.ciphertext[2:], encrypted) + return preMasterSecret, ckx, nil +} + +// sha1Hash calculates a SHA1 hash over the given byte slices. +func sha1Hash(slices [][]byte) []byte { + hsha1 := sha1.New() + for _, slice := range slices { + hsha1.Write(slice) + } + return hsha1.Sum(nil) +} + +// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the +// concatenation of an MD5 and SHA1 hash. +func md5SHA1Hash(slices [][]byte) []byte { + md5sha1 := make([]byte, md5.Size+sha1.Size) + hmd5 := md5.New() + for _, slice := range slices { + hmd5.Write(slice) + } + copy(md5sha1, hmd5.Sum(nil)) + copy(md5sha1[md5.Size:], sha1Hash(slices)) + return md5sha1 +} + +// hashForServerKeyExchange hashes the given slices and returns their digest +// using the given hash function (for >= TLS 1.2) or using a default based on +// the sigType (for earlier TLS versions). For Ed25519 signatures, which don't +// do pre-hashing, it returns the concatenation of the slices. +func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte { + if sigType == signatureEd25519 { + var signed []byte + for _, slice := range slices { + signed = append(signed, slice...) + } + return signed + } + if version >= VersionTLS12 { + h := hashFunc.New() + for _, slice := range slices { + h.Write(slice) + } + digest := h.Sum(nil) + return digest + } + if sigType == signatureECDSA { + return sha1Hash(slices) + } + return md5SHA1Hash(slices) +} + +// ecdheKeyAgreement implements a TLS key agreement where the server +// generates an ephemeral EC public/private key pair and signs it. The +// pre-master secret is then calculated using ECDH. The signature may +// be ECDSA, Ed25519 or RSA. +type ecdheKeyAgreement struct { + version uint16 + isRSA bool + params ecdheParameters + + // ckx and preMasterSecret are generated in processServerKeyExchange + // and returned in generateClientKeyExchange. + ckx *clientKeyExchangeMsg + preMasterSecret []byte +} + +func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { + var curveID CurveID + for _, c := range clientHello.supportedCurves { + if config.supportsCurve(c) { + curveID = c + break + } + } + + if curveID == 0 { + return nil, errors.New("tls: no supported elliptic curves offered") + } + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + return nil, errors.New("tls: CurvePreferences includes unsupported curve") + } + + params, err := generateECDHEParameters(config.rand(), curveID) + if err != nil { + return nil, err + } + ka.params = params + + // See RFC 4492, Section 5.4. + ecdhePublic := params.PublicKey() + serverECDHEParams := make([]byte, 1+2+1+len(ecdhePublic)) + serverECDHEParams[0] = 3 // named curve + serverECDHEParams[1] = byte(curveID >> 8) + serverECDHEParams[2] = byte(curveID) + serverECDHEParams[3] = byte(len(ecdhePublic)) + copy(serverECDHEParams[4:], ecdhePublic) + + priv, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return nil, fmt.Errorf("tls: certificate private key of type %T does not implement crypto.Signer", cert.PrivateKey) + } + + var signatureAlgorithm SignatureScheme + var sigType uint8 + var sigHash crypto.Hash + if ka.version >= VersionTLS12 { + signatureAlgorithm, err = selectSignatureScheme(ka.version, cert, clientHello.supportedSignatureAlgorithms) + if err != nil { + return nil, err + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) + if err != nil { + return nil, err + } + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(priv.Public()) + if err != nil { + return nil, err + } + } + if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { + return nil, errors.New("tls: certificate cannot be used with the selected cipher suite") + } + + signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, hello.random, serverECDHEParams) + + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + sig, err := priv.Sign(config.rand(), signed, signOpts) + if err != nil { + return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error()) + } + + skx := new(serverKeyExchangeMsg) + sigAndHashLen := 0 + if ka.version >= VersionTLS12 { + sigAndHashLen = 2 + } + skx.key = make([]byte, len(serverECDHEParams)+sigAndHashLen+2+len(sig)) + copy(skx.key, serverECDHEParams) + k := skx.key[len(serverECDHEParams):] + if ka.version >= VersionTLS12 { + k[0] = byte(signatureAlgorithm >> 8) + k[1] = byte(signatureAlgorithm) + k = k[2:] + } + k[0] = byte(len(sig) >> 8) + k[1] = byte(len(sig)) + copy(k[2:], sig) + + return skx, nil +} + +func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { + if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 { + return nil, errClientKeyExchange + } + + preMasterSecret := ka.params.SharedKey(ckx.ciphertext[1:]) + if preMasterSecret == nil { + return nil, errClientKeyExchange + } + + return preMasterSecret, nil +} + +func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { + if len(skx.key) < 4 { + return errServerKeyExchange + } + if skx.key[0] != 3 { // named curve + return errors.New("tls: server selected unsupported curve") + } + curveID := CurveID(skx.key[1])<<8 | CurveID(skx.key[2]) + + publicLen := int(skx.key[3]) + if publicLen+4 > len(skx.key) { + return errServerKeyExchange + } + serverECDHEParams := skx.key[:4+publicLen] + publicKey := serverECDHEParams[4:] + + sig := skx.key[4+publicLen:] + if len(sig) < 2 { + return errServerKeyExchange + } + + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + return errors.New("tls: server selected unsupported curve") + } + + params, err := generateECDHEParameters(config.rand(), curveID) + if err != nil { + return err + } + ka.params = params + + ka.preMasterSecret = params.SharedKey(publicKey) + if ka.preMasterSecret == nil { + return errServerKeyExchange + } + + ourPublicKey := params.PublicKey() + ka.ckx = new(clientKeyExchangeMsg) + ka.ckx.ciphertext = make([]byte, 1+len(ourPublicKey)) + ka.ckx.ciphertext[0] = byte(len(ourPublicKey)) + copy(ka.ckx.ciphertext[1:], ourPublicKey) + + var sigType uint8 + var sigHash crypto.Hash + if ka.version >= VersionTLS12 { + signatureAlgorithm := SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1]) + sig = sig[2:] + if len(sig) < 2 { + return errServerKeyExchange + } + + if !isSupportedSignatureAlgorithm(signatureAlgorithm, clientHello.supportedSignatureAlgorithms) { + return errors.New("tls: certificate used with invalid signature algorithm") + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) + if err != nil { + return err + } + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(cert.PublicKey) + if err != nil { + return err + } + } + if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { + return errServerKeyExchange + } + + sigLen := int(sig[0])<<8 | int(sig[1]) + if sigLen+2 != len(sig) { + return errServerKeyExchange + } + sig = sig[2:] + + signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, serverHello.random, serverECDHEParams) + if err := verifyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil { + return errors.New("tls: invalid signature by the server certificate: " + err.Error()) + } + return nil +} + +func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { + if ka.ckx == nil { + return nil, nil, errors.New("tls: missing ServerKeyExchange message") + } + + return ka.preMasterSecret, ka.ckx, nil +} diff --git a/pkg/tls/key_schedule.go b/pkg/tls/key_schedule.go new file mode 100644 index 000000000..314016979 --- /dev/null +++ b/pkg/tls/key_schedule.go @@ -0,0 +1,199 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "crypto/elliptic" + "crypto/hmac" + "errors" + "hash" + "io" + "math/big" + + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/hkdf" +) + +// This file contains the functions necessary to compute the TLS 1.3 key +// schedule. See RFC 8446, Section 7. + +const ( + resumptionBinderLabel = "res binder" + clientHandshakeTrafficLabel = "c hs traffic" + serverHandshakeTrafficLabel = "s hs traffic" + clientApplicationTrafficLabel = "c ap traffic" + serverApplicationTrafficLabel = "s ap traffic" + exporterLabel = "exp master" + resumptionLabel = "res master" + trafficUpdateLabel = "traffic upd" +) + +// expandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1. +func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []byte, length int) []byte { + var hkdfLabel cryptobyte.Builder + hkdfLabel.AddUint16(uint16(length)) + hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte("tls13 ")) + b.AddBytes([]byte(label)) + }) + hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(context) + }) + out := make([]byte, length) + n, err := hkdf.Expand(c.hash.New, secret, hkdfLabel.BytesOrPanic()).Read(out) + if err != nil || n != length { + panic("tls: HKDF-Expand-Label invocation failed unexpectedly") + } + return out +} + +// deriveSecret implements Derive-Secret from RFC 8446, Section 7.1. +func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte { + if transcript == nil { + transcript = c.hash.New() + } + return c.expandLabel(secret, label, transcript.Sum(nil), c.hash.Size()) +} + +// extract implements HKDF-Extract with the cipher suite hash. +func (c *cipherSuiteTLS13) extract(newSecret, currentSecret []byte) []byte { + if newSecret == nil { + newSecret = make([]byte, c.hash.Size()) + } + return hkdf.Extract(c.hash.New, newSecret, currentSecret) +} + +// nextTrafficSecret generates the next traffic secret, given the current one, +// according to RFC 8446, Section 7.2. +func (c *cipherSuiteTLS13) nextTrafficSecret(trafficSecret []byte) []byte { + return c.expandLabel(trafficSecret, trafficUpdateLabel, nil, c.hash.Size()) +} + +// trafficKey generates traffic keys according to RFC 8446, Section 7.3. +func (c *cipherSuiteTLS13) trafficKey(trafficSecret []byte) (key, iv []byte) { + key = c.expandLabel(trafficSecret, "key", nil, c.keyLen) + iv = c.expandLabel(trafficSecret, "iv", nil, aeadNonceLength) + return +} + +// finishedHash generates the Finished verify_data or PskBinderEntry according +// to RFC 8446, Section 4.4.4. See sections 4.4 and 4.2.11.2 for the baseKey +// selection. +func (c *cipherSuiteTLS13) finishedHash(baseKey []byte, transcript hash.Hash) []byte { + finishedKey := c.expandLabel(baseKey, "finished", nil, c.hash.Size()) + verifyData := hmac.New(c.hash.New, finishedKey) + verifyData.Write(transcript.Sum(nil)) + return verifyData.Sum(nil) +} + +// exportKeyingMaterial implements RFC5705 exporters for TLS 1.3 according to +// RFC 8446, Section 7.5. +func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript hash.Hash) func(string, []byte, int) ([]byte, error) { + expMasterSecret := c.deriveSecret(masterSecret, exporterLabel, transcript) + return func(label string, context []byte, length int) ([]byte, error) { + secret := c.deriveSecret(expMasterSecret, label, nil) + h := c.hash.New() + h.Write(context) + return c.expandLabel(secret, "exporter", h.Sum(nil), length), nil + } +} + +// ecdheParameters implements Diffie-Hellman with either NIST curves or X25519, +// according to RFC 8446, Section 4.2.8.2. +type ecdheParameters interface { + CurveID() CurveID + PublicKey() []byte + SharedKey(peerPublicKey []byte) []byte +} + +func generateECDHEParameters(rand io.Reader, curveID CurveID) (ecdheParameters, error) { + if curveID == X25519 { + privateKey := make([]byte, curve25519.ScalarSize) + if _, err := io.ReadFull(rand, privateKey); err != nil { + return nil, err + } + publicKey, err := curve25519.X25519(privateKey, curve25519.Basepoint) + if err != nil { + return nil, err + } + return &x25519Parameters{privateKey: privateKey, publicKey: publicKey}, nil + } + + curve, ok := curveForCurveID(curveID) + if !ok { + return nil, errors.New("tls: internal error: unsupported curve") + } + + p := &nistParameters{curveID: curveID} + var err error + p.privateKey, p.x, p.y, err = elliptic.GenerateKey(curve, rand) + if err != nil { + return nil, err + } + return p, nil +} + +func curveForCurveID(id CurveID) (elliptic.Curve, bool) { + switch id { + case CurveP256: + return elliptic.P256(), true + case CurveP384: + return elliptic.P384(), true + case CurveP521: + return elliptic.P521(), true + default: + return nil, false + } +} + +type nistParameters struct { + privateKey []byte + x, y *big.Int // public key + curveID CurveID +} + +func (p *nistParameters) CurveID() CurveID { + return p.curveID +} + +func (p *nistParameters) PublicKey() []byte { + curve, _ := curveForCurveID(p.curveID) + return elliptic.Marshal(curve, p.x, p.y) +} + +func (p *nistParameters) SharedKey(peerPublicKey []byte) []byte { + curve, _ := curveForCurveID(p.curveID) + // Unmarshal also checks whether the given point is on the curve. + x, y := elliptic.Unmarshal(curve, peerPublicKey) + if x == nil { + return nil + } + + xShared, _ := curve.ScalarMult(x, y, p.privateKey) + sharedKey := make([]byte, (curve.Params().BitSize+7)/8) + return xShared.FillBytes(sharedKey) +} + +type x25519Parameters struct { + privateKey []byte + publicKey []byte +} + +func (p *x25519Parameters) CurveID() CurveID { + return X25519 +} + +func (p *x25519Parameters) PublicKey() []byte { + return p.publicKey[:] +} + +func (p *x25519Parameters) SharedKey(peerPublicKey []byte) []byte { + sharedKey, err := curve25519.X25519(p.privateKey, peerPublicKey) + if err != nil { + return nil + } + return sharedKey +} diff --git a/pkg/tls/prf.go b/pkg/tls/prf.go new file mode 100644 index 000000000..13bfa009c --- /dev/null +++ b/pkg/tls/prf.go @@ -0,0 +1,283 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "crypto" + "crypto/hmac" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "errors" + "fmt" + "hash" +) + +// Split a premaster secret in two as specified in RFC 4346, Section 5. +func splitPreMasterSecret(secret []byte) (s1, s2 []byte) { + s1 = secret[0 : (len(secret)+1)/2] + s2 = secret[len(secret)/2:] + return +} + +// pHash implements the P_hash function, as defined in RFC 4346, Section 5. +func pHash(result, secret, seed []byte, hash func() hash.Hash) { + h := hmac.New(hash, secret) + h.Write(seed) + a := h.Sum(nil) + + j := 0 + for j < len(result) { + h.Reset() + h.Write(a) + h.Write(seed) + b := h.Sum(nil) + copy(result[j:], b) + j += len(b) + + h.Reset() + h.Write(a) + a = h.Sum(nil) + } +} + +// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5. +func prf10(result, secret, label, seed []byte) { + hashSHA1 := sha1.New + hashMD5 := md5.New + + labelAndSeed := make([]byte, len(label)+len(seed)) + copy(labelAndSeed, label) + copy(labelAndSeed[len(label):], seed) + + s1, s2 := splitPreMasterSecret(secret) + pHash(result, s1, labelAndSeed, hashMD5) + result2 := make([]byte, len(result)) + pHash(result2, s2, labelAndSeed, hashSHA1) + + for i, b := range result2 { + result[i] ^= b + } +} + +// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, Section 5. +func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) { + return func(result, secret, label, seed []byte) { + labelAndSeed := make([]byte, len(label)+len(seed)) + copy(labelAndSeed, label) + copy(labelAndSeed[len(label):], seed) + + pHash(result, secret, labelAndSeed, hashFunc) + } +} + +const ( + masterSecretLength = 48 // Length of a master secret in TLS 1.1. + finishedVerifyLength = 12 // Length of verify_data in a Finished message. +) + +var masterSecretLabel = []byte("master secret") +var keyExpansionLabel = []byte("key expansion") +var clientFinishedLabel = []byte("client finished") +var serverFinishedLabel = []byte("server finished") + +func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) { + switch version { + case VersionTLS10, VersionTLS11: + return prf10, crypto.Hash(0) + case VersionTLS12: + if suite.flags&suiteSHA384 != 0 { + return prf12(sha512.New384), crypto.SHA384 + } + return prf12(sha256.New), crypto.SHA256 + default: + panic("unknown version") + } +} + +func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) { + prf, _ := prfAndHashForVersion(version, suite) + return prf +} + +// masterFromPreMasterSecret generates the master secret from the pre-master +// secret. See RFC 5246, Section 8.1. +func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte { + seed := make([]byte, 0, len(clientRandom)+len(serverRandom)) + seed = append(seed, clientRandom...) + seed = append(seed, serverRandom...) + + masterSecret := make([]byte, masterSecretLength) + prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed) + return masterSecret +} + +// keysFromMasterSecret generates the connection keys from the master +// secret, given the lengths of the MAC key, cipher key and IV, as defined in +// RFC 2246, Section 6.3. +func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) { + seed := make([]byte, 0, len(serverRandom)+len(clientRandom)) + seed = append(seed, serverRandom...) + seed = append(seed, clientRandom...) + + n := 2*macLen + 2*keyLen + 2*ivLen + keyMaterial := make([]byte, n) + prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed) + clientMAC = keyMaterial[:macLen] + keyMaterial = keyMaterial[macLen:] + serverMAC = keyMaterial[:macLen] + keyMaterial = keyMaterial[macLen:] + clientKey = keyMaterial[:keyLen] + keyMaterial = keyMaterial[keyLen:] + serverKey = keyMaterial[:keyLen] + keyMaterial = keyMaterial[keyLen:] + clientIV = keyMaterial[:ivLen] + keyMaterial = keyMaterial[ivLen:] + serverIV = keyMaterial[:ivLen] + return +} + +func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash { + var buffer []byte + if version >= VersionTLS12 { + buffer = []byte{} + } + + prf, hash := prfAndHashForVersion(version, cipherSuite) + if hash != 0 { + return finishedHash{hash.New(), hash.New(), nil, nil, buffer, version, prf} + } + + return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), buffer, version, prf} +} + +// A finishedHash calculates the hash of a set of handshake messages suitable +// for including in a Finished message. +type finishedHash struct { + client hash.Hash + server hash.Hash + + // Prior to TLS 1.2, an additional MD5 hash is required. + clientMD5 hash.Hash + serverMD5 hash.Hash + + // In TLS 1.2, a full buffer is sadly required. + buffer []byte + + version uint16 + prf func(result, secret, label, seed []byte) +} + +func (h *finishedHash) Write(msg []byte) (n int, err error) { + h.client.Write(msg) + h.server.Write(msg) + + if h.version < VersionTLS12 { + h.clientMD5.Write(msg) + h.serverMD5.Write(msg) + } + + if h.buffer != nil { + h.buffer = append(h.buffer, msg...) + } + + return len(msg), nil +} + +func (h finishedHash) Sum() []byte { + if h.version >= VersionTLS12 { + return h.client.Sum(nil) + } + + out := make([]byte, 0, md5.Size+sha1.Size) + out = h.clientMD5.Sum(out) + return h.client.Sum(out) +} + +// clientSum returns the contents of the verify_data member of a client's +// Finished message. +func (h finishedHash) clientSum(masterSecret []byte) []byte { + out := make([]byte, finishedVerifyLength) + h.prf(out, masterSecret, clientFinishedLabel, h.Sum()) + return out +} + +// serverSum returns the contents of the verify_data member of a server's +// Finished message. +func (h finishedHash) serverSum(masterSecret []byte) []byte { + out := make([]byte, finishedVerifyLength) + h.prf(out, masterSecret, serverFinishedLabel, h.Sum()) + return out +} + +// hashForClientCertificate returns the handshake messages so far, pre-hashed if +// necessary, suitable for signing by a TLS client certificate. +func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash, masterSecret []byte) []byte { + if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil { + panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer") + } + + if sigType == signatureEd25519 { + return h.buffer + } + + if h.version >= VersionTLS12 { + hash := hashAlg.New() + hash.Write(h.buffer) + return hash.Sum(nil) + } + + if sigType == signatureECDSA { + return h.server.Sum(nil) + } + + return h.Sum() +} + +// discardHandshakeBuffer is called when there is no more need to +// buffer the entirety of the handshake messages. +func (h *finishedHash) discardHandshakeBuffer() { + h.buffer = nil +} + +// noExportedKeyingMaterial is used as a value of +// ConnectionState.ekm when renegotiation is enabled and thus +// we wish to fail all key-material export requests. +func noExportedKeyingMaterial(label string, context []byte, length int) ([]byte, error) { + return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when renegotiation is enabled") +} + +// ekmFromMasterSecret generates exported keying material as defined in RFC 5705. +func ekmFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte) func(string, []byte, int) ([]byte, error) { + return func(label string, context []byte, length int) ([]byte, error) { + switch label { + case "client finished", "server finished", "master secret", "key expansion": + // These values are reserved and may not be used. + return nil, fmt.Errorf("crypto/tls: reserved ExportKeyingMaterial label: %s", label) + } + + seedLen := len(serverRandom) + len(clientRandom) + if context != nil { + seedLen += 2 + len(context) + } + seed := make([]byte, 0, seedLen) + + seed = append(seed, clientRandom...) + seed = append(seed, serverRandom...) + + if context != nil { + if len(context) >= 1<<16 { + return nil, fmt.Errorf("crypto/tls: ExportKeyingMaterial context too long") + } + seed = append(seed, byte(len(context)>>8), byte(len(context))) + seed = append(seed, context...) + } + + keyMaterial := make([]byte, length) + prfForVersion(version, suite)(keyMaterial, masterSecret, []byte(label), seed) + return keyMaterial, nil + } +} diff --git a/pkg/tls/ticket.go b/pkg/tls/ticket.go new file mode 100644 index 000000000..6c1d20da2 --- /dev/null +++ b/pkg/tls/ticket.go @@ -0,0 +1,185 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/sha256" + "crypto/subtle" + "errors" + "io" + + "golang.org/x/crypto/cryptobyte" +) + +// sessionState contains the information that is serialized into a session +// ticket in order to later resume a connection. +type sessionState struct { + vers uint16 + cipherSuite uint16 + createdAt uint64 + masterSecret []byte // opaque master_secret<1..2^16-1>; + // struct { opaque certificate<1..2^24-1> } Certificate; + certificates [][]byte // Certificate certificate_list<0..2^24-1>; + + // usedOldKey is true if the ticket from which this session came from + // was encrypted with an older key and thus should be refreshed. + usedOldKey bool +} + +func (m *sessionState) marshal() []byte { + var b cryptobyte.Builder + b.AddUint16(m.vers) + b.AddUint16(m.cipherSuite) + addUint64(&b, m.createdAt) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.masterSecret) + }) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + for _, cert := range m.certificates { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(cert) + }) + } + }) + return b.BytesOrPanic() +} + +func (m *sessionState) unmarshal(data []byte) bool { + *m = sessionState{usedOldKey: m.usedOldKey} + s := cryptobyte.String(data) + if ok := s.ReadUint16(&m.vers) && + s.ReadUint16(&m.cipherSuite) && + readUint64(&s, &m.createdAt) && + readUint16LengthPrefixed(&s, &m.masterSecret) && + len(m.masterSecret) != 0; !ok { + return false + } + var certList cryptobyte.String + if !s.ReadUint24LengthPrefixed(&certList) { + return false + } + for !certList.Empty() { + var cert []byte + if !readUint24LengthPrefixed(&certList, &cert) { + return false + } + m.certificates = append(m.certificates, cert) + } + return s.Empty() +} + +// sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first +// version (revision = 0) doesn't carry any of the information needed for 0-RTT +// validation and the nonce is always empty. +type sessionStateTLS13 struct { + // uint8 version = 0x0304; + // uint8 revision = 0; + cipherSuite uint16 + createdAt uint64 + resumptionSecret []byte // opaque resumption_master_secret<1..2^8-1>; + certificate Certificate // CertificateEntry certificate_list<0..2^24-1>; +} + +func (m *sessionStateTLS13) marshal() []byte { + var b cryptobyte.Builder + b.AddUint16(VersionTLS13) + b.AddUint8(0) // revision + b.AddUint16(m.cipherSuite) + addUint64(&b, m.createdAt) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.resumptionSecret) + }) + marshalCertificate(&b, m.certificate) + return b.BytesOrPanic() +} + +func (m *sessionStateTLS13) unmarshal(data []byte) bool { + *m = sessionStateTLS13{} + s := cryptobyte.String(data) + var version uint16 + var revision uint8 + return s.ReadUint16(&version) && + version == VersionTLS13 && + s.ReadUint8(&revision) && + revision == 0 && + s.ReadUint16(&m.cipherSuite) && + readUint64(&s, &m.createdAt) && + readUint8LengthPrefixed(&s, &m.resumptionSecret) && + len(m.resumptionSecret) != 0 && + unmarshalCertificate(&s, &m.certificate) && + s.Empty() +} + +func (c *Conn) encryptTicket(state []byte) ([]byte, error) { + if len(c.ticketKeys) == 0 { + return nil, errors.New("tls: internal error: session ticket keys unavailable") + } + + encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(state)+sha256.Size) + keyName := encrypted[:ticketKeyNameLen] + iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] + macBytes := encrypted[len(encrypted)-sha256.Size:] + + if _, err := io.ReadFull(c.config.rand(), iv); err != nil { + return nil, err + } + key := c.ticketKeys[0] + copy(keyName, key.keyName[:]) + block, err := aes.NewCipher(key.aesKey[:]) + if err != nil { + return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error()) + } + cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], state) + + mac := hmac.New(sha256.New, key.hmacKey[:]) + mac.Write(encrypted[:len(encrypted)-sha256.Size]) + mac.Sum(macBytes[:0]) + + return encrypted, nil +} + +func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) { + if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size { + return nil, false + } + + keyName := encrypted[:ticketKeyNameLen] + iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] + macBytes := encrypted[len(encrypted)-sha256.Size:] + ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size] + + keyIndex := -1 + for i, candidateKey := range c.ticketKeys { + if bytes.Equal(keyName, candidateKey.keyName[:]) { + keyIndex = i + break + } + } + if keyIndex == -1 { + return nil, false + } + key := &c.ticketKeys[keyIndex] + + mac := hmac.New(sha256.New, key.hmacKey[:]) + mac.Write(encrypted[:len(encrypted)-sha256.Size]) + expected := mac.Sum(nil) + + if subtle.ConstantTimeCompare(macBytes, expected) != 1 { + return nil, false + } + + block, err := aes.NewCipher(key.aesKey[:]) + if err != nil { + return nil, false + } + plaintext = make([]byte, len(ciphertext)) + cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext) + + return plaintext, keyIndex > 0 +} diff --git a/pkg/tls/tls.go b/pkg/tls/tls.go new file mode 100644 index 000000000..5ede0b614 --- /dev/null +++ b/pkg/tls/tls.go @@ -0,0 +1,205 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package tls partially implements TLS 1.2, as specified in RFC 5246, +// and TLS 1.3, as specified in RFC 8446. +package tls + +// BUG(agl): The crypto/tls package only implements some countermeasures +// against Lucky13 attacks on CBC-mode encryption, and only on SHA1 +// variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and +// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "io/ioutil" + "net" + "strings" + + "github.com/panjf2000/gnet/v2/pkg/buffer/elastic" +) + +type conn interface { + Write([]byte) (int, error) + RemoteAddr() net.Addr +} + +// Server returns a new TLS server side connection +// using conn as the underlying transport. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func Server(c conn, in *elastic.RingBuffer, out *elastic.Buffer, config *Config) (*Conn, error) { + tlsconn := &Conn{ + conn: c, + config: config, + input: in, + sendBuf: out, + outBuf: []byte{0, 3, 3, 0, 0}, + } + + return tlsconn, nil +} +func Client(c conn, in *elastic.RingBuffer, out *elastic.Buffer, config *Config) *Conn { + tlsconn := &Conn{ + conn: c, + config: config, + input: in, + sendBuf: out, + outBuf: []byte{0, 3, 3, 0, 0}, + isClient: true, + } + return tlsconn +} + +// Client returns a new TLS client side connection +// using conn as the underlying transport. +// The config cannot be nil: users must set either ServerName or +// InsecureSkipVerify in the config. + +type timeoutError struct{} + +func (timeoutError) Error() string { return "tls: DialWithDialer timed out" } +func (timeoutError) Timeout() bool { return true } +func (timeoutError) Temporary() bool { return true } + +// LoadX509KeyPair reads and parses a public/private key pair from a pair +// of files. The files must contain PEM encoded data. The certificate file +// may contain intermediate certificates following the leaf certificate to +// form a certificate chain. On successful return, Certificate.Leaf will +// be nil because the parsed form of the certificate is not retained. +func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) { + certPEMBlock, err := ioutil.ReadFile(certFile) + if err != nil { + return Certificate{}, err + } + keyPEMBlock, err := ioutil.ReadFile(keyFile) + if err != nil { + return Certificate{}, err + } + return X509KeyPair(certPEMBlock, keyPEMBlock) +} + +// X509KeyPair parses a public/private key pair from a pair of +// PEM encoded data. On successful return, Certificate.Leaf will be nil because +// the parsed form of the certificate is not retained. +func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { + fail := func(err error) (Certificate, error) { return Certificate{}, err } + + var cert Certificate + var skippedBlockTypes []string + for { + var certDERBlock *pem.Block + certDERBlock, certPEMBlock = pem.Decode(certPEMBlock) + if certDERBlock == nil { + break + } + if certDERBlock.Type == "CERTIFICATE" { + cert.Certificate = append(cert.Certificate, certDERBlock.Bytes) + } else { + skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type) + } + } + + if len(cert.Certificate) == 0 { + if len(skippedBlockTypes) == 0 { + return fail(errors.New("tls: failed to find any PEM data in certificate input")) + } + if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") { + return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched")) + } + return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) + } + + skippedBlockTypes = skippedBlockTypes[:0] + var keyDERBlock *pem.Block + for { + keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock) + if keyDERBlock == nil { + if len(skippedBlockTypes) == 0 { + return fail(errors.New("tls: failed to find any PEM data in key input")) + } + if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" { + return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key")) + } + return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) + } + if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") { + break + } + skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type) + } + + // We don't need to parse the public key for TLS, but we so do anyway + // to check that it looks sane and matches the private key. + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return fail(err) + } + + cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes) + if err != nil { + return fail(err) + } + + switch pub := x509Cert.PublicKey.(type) { + case *rsa.PublicKey: + priv, ok := cert.PrivateKey.(*rsa.PrivateKey) + if !ok { + return fail(errors.New("tls: private key type does not match public key type")) + } + if pub.N.Cmp(priv.N) != 0 { + return fail(errors.New("tls: private key does not match public key")) + } + case *ecdsa.PublicKey: + priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey) + if !ok { + return fail(errors.New("tls: private key type does not match public key type")) + } + if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { + return fail(errors.New("tls: private key does not match public key")) + } + case ed25519.PublicKey: + priv, ok := cert.PrivateKey.(ed25519.PrivateKey) + if !ok { + return fail(errors.New("tls: private key type does not match public key type")) + } + if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) { + return fail(errors.New("tls: private key does not match public key")) + } + default: + return fail(errors.New("tls: unknown public key algorithm")) + } + + return cert, nil +} + +// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates +// PKCS#1 private keys by default, while OpenSSL 1.0.0 generates PKCS#8 keys. +// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three. +func parsePrivateKey(der []byte) (crypto.PrivateKey, error) { + if key, err := x509.ParsePKCS1PrivateKey(der); err == nil { + return key, nil + } + if key, err := x509.ParsePKCS8PrivateKey(der); err == nil { + switch key := key.(type) { + case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey: + return key, nil + default: + return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping") + } + } + if key, err := x509.ParseECPrivateKey(der); err == nil { + return key, nil + } + + return nil, errors.New("tls: failed to parse private key") +} From 2e073d2078c3db4a272c39012ce6044afaf88473 Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Fri, 20 Jan 2023 16:29:57 +0000 Subject: [PATCH 02/34] 1. merge tls to go 1.20rc3 as close as possible 2. change the gnet API name for the TLS server & client 3. gnet TLS write returns the exact number of bytes written to the socket rather than the lenght of data. --- connection.go | 10 +- eventloop.go | 5 + internal/boring/doc.go | 19 + internal/boring/notboring.go | 123 +++++++ internal/boring/rand.go | 24 ++ internal/boring/sig/sig.go | 17 + internal/boring/sig/sig_amd64.s | 54 +++ internal/boring/sig/sig_other.s | 20 ++ pkg/tls/auth.go | 8 +- pkg/tls/auth_test.go | 168 +++++++++ pkg/tls/cache.go | 95 +++++ pkg/tls/cache_test.go | 117 ++++++ pkg/tls/cipher_suites.go | 331 ++++++++++++----- pkg/tls/common.go | 340 ++++++++++-------- pkg/tls/common_string.go | 116 ++++++ pkg/tls/conn.go | 553 ++++++++++++++++++----------- pkg/tls/generate_cert.go | 171 +++++++++ pkg/tls/handshake_client.go | 196 +++++----- pkg/tls/handshake_client_tls13.go | 45 ++- pkg/tls/handshake_messages.go | 17 +- pkg/tls/handshake_messages_test.go | 486 +++++++++++++++++++++++++ pkg/tls/handshake_server.go | 110 ++++-- pkg/tls/handshake_server_tls13.go | 74 ++-- pkg/tls/handshake_test.go | 530 +++++++++++++++++++++++++++ pkg/tls/handshake_unix_test.go | 18 + pkg/tls/key_agreement.go | 60 +++- pkg/tls/key_schedule.go | 102 ++---- pkg/tls/key_schedule_test.go | 175 +++++++++ pkg/tls/notboring.go | 20 ++ pkg/tls/prf.go | 2 +- pkg/tls/prf_test.go | 140 ++++++++ pkg/tls/tls.go | 67 ++-- pkg/tls/tls_test.go | 25 ++ 33 files changed, 3529 insertions(+), 709 deletions(-) create mode 100644 internal/boring/doc.go create mode 100644 internal/boring/notboring.go create mode 100644 internal/boring/rand.go create mode 100644 internal/boring/sig/sig.go create mode 100644 internal/boring/sig/sig_amd64.s create mode 100644 internal/boring/sig/sig_other.s create mode 100644 pkg/tls/auth_test.go create mode 100644 pkg/tls/cache.go create mode 100644 pkg/tls/cache_test.go create mode 100644 pkg/tls/common_string.go create mode 100644 pkg/tls/generate_cert.go create mode 100644 pkg/tls/handshake_messages_test.go create mode 100644 pkg/tls/handshake_test.go create mode 100644 pkg/tls/handshake_unix_test.go create mode 100644 pkg/tls/key_schedule_test.go create mode 100644 pkg/tls/notboring.go create mode 100644 pkg/tls/prf_test.go create mode 100644 pkg/tls/tls_test.go diff --git a/connection.go b/connection.go index 94d7e61ad..91cab70be 100644 --- a/connection.go +++ b/connection.go @@ -134,13 +134,13 @@ func (c *conn) write(data []byte) (n int, err error) { if c.tlsconn != nil { // use tls to encrypt the data before sending it - c.tlsconn.Write(data) + n, _ = c.tlsconn.Write(data) // err = c.loop.poller.ModReadWrite(c.pollAttachment) // n = 0 // also working err = c.loop.write(c) return - } + } // If there is pending data in outbound buffer, the current data ought to be appended to the outbound buffer // for maintaining the sequence of network packets. @@ -494,11 +494,13 @@ func (c *conn) Close() error { } func (c *conn) UpgradeTLS(config *tls.Config) (err error) { - c.tlsconn, err = tls.Server(c, &c.inboundBuffer, c.outboundBuffer, config.Clone()) + c.tlsconn = tls.ServerGnet(c, &c.inboundBuffer, c.outboundBuffer, config.Clone()) //很有可能握手包在UpgradeTls之前发过来了,这里把inboundBuffer剩余数据当做握手数据处理 if c.inboundBuffer.Len() > 0 { - c.tlsconn.RawWrite(c.inboundBuffer.Bytes()) + head, tail := c.inboundBuffer.Peek(-1) + c.tlsconn.RawWrite(head) + c.tlsconn.RawWrite(tail) c.inboundBuffer.Reset() if err := c.tlsconn.Handshake(); err != nil { return err diff --git a/eventloop.go b/eventloop.go index 7b0fff0bc..70dc8123f 100644 --- a/eventloop.go +++ b/eventloop.go @@ -228,6 +228,11 @@ func (el *eventloop) closeConn(c *conn, err error) (rerr error) { return } + // clost the TLS connection by sending the alert + if c.tlsconn != nil { + c.tlsconn.Close() + } + // Send residual data in buffer back to the peer before actually closing the connection. if !c.outboundBuffer.IsEmpty() { for !c.outboundBuffer.IsEmpty() { diff --git a/internal/boring/doc.go b/internal/boring/doc.go new file mode 100644 index 000000000..6060fe595 --- /dev/null +++ b/internal/boring/doc.go @@ -0,0 +1,19 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package boring provides access to BoringCrypto implementation functions. +// Check the constant Enabled to find out whether BoringCrypto is available. +// If BoringCrypto is not available, the functions in this package all panic. +package boring + +// Enabled reports whether BoringCrypto is available. +// When enabled is false, all functions in this package panic. +// +// BoringCrypto is only available on linux/amd64 systems. +const Enabled = available + +// A BigInt is the raw words from a BigInt. +// This definition allows us to avoid importing math/big. +// Conversion between BigInt and *big.Int is in crypto/internal/boring/bbig. +type BigInt []uint diff --git a/internal/boring/notboring.go b/internal/boring/notboring.go new file mode 100644 index 000000000..6341d5b16 --- /dev/null +++ b/internal/boring/notboring.go @@ -0,0 +1,123 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !(boringcrypto && linux && (amd64 || arm64) && !android && !cmd_go_bootstrap && !msan && cgo) + +package boring + +import ( + "crypto" + "crypto/cipher" + "hash" + + "github.com/panjf2000/gnet/v2/internal/boring/sig" +) + +const available = false + +// Unreachable marks code that should be unreachable +// when BoringCrypto is in use. It is a no-op without BoringCrypto. +func Unreachable() { + // Code that's unreachable when using BoringCrypto + // is exactly the code we want to detect for reporting + // standard Go crypto. + sig.StandardCrypto() +} + +// UnreachableExceptTests marks code that should be unreachable +// when BoringCrypto is in use. It is a no-op without BoringCrypto. +func UnreachableExceptTests() {} + +type randReader int + +func (randReader) Read(b []byte) (int, error) { panic("boringcrypto: not available") } + +const RandReader = randReader(0) + +func NewSHA1() hash.Hash { panic("boringcrypto: not available") } +func NewSHA224() hash.Hash { panic("boringcrypto: not available") } +func NewSHA256() hash.Hash { panic("boringcrypto: not available") } +func NewSHA384() hash.Hash { panic("boringcrypto: not available") } +func NewSHA512() hash.Hash { panic("boringcrypto: not available") } + +func SHA1([]byte) [20]byte { panic("boringcrypto: not available") } +func SHA224([]byte) [28]byte { panic("boringcrypto: not available") } +func SHA256([]byte) [32]byte { panic("boringcrypto: not available") } +func SHA384([]byte) [48]byte { panic("boringcrypto: not available") } +func SHA512([]byte) [64]byte { panic("boringcrypto: not available") } + +func NewHMAC(h func() hash.Hash, key []byte) hash.Hash { panic("boringcrypto: not available") } + +func NewAESCipher(key []byte) (cipher.Block, error) { panic("boringcrypto: not available") } +func NewGCMTLS(cipher.Block) (cipher.AEAD, error) { panic("boringcrypto: not available") } + +type PublicKeyECDSA struct{ _ int } +type PrivateKeyECDSA struct{ _ int } + +func GenerateKeyECDSA(curve string) (X, Y, D BigInt, err error) { + panic("boringcrypto: not available") +} +func NewPrivateKeyECDSA(curve string, X, Y, D BigInt) (*PrivateKeyECDSA, error) { + panic("boringcrypto: not available") +} +func NewPublicKeyECDSA(curve string, X, Y BigInt) (*PublicKeyECDSA, error) { + panic("boringcrypto: not available") +} +func SignMarshalECDSA(priv *PrivateKeyECDSA, hash []byte) ([]byte, error) { + panic("boringcrypto: not available") +} +func VerifyECDSA(pub *PublicKeyECDSA, hash []byte, sig []byte) bool { + panic("boringcrypto: not available") +} + +type PublicKeyRSA struct{ _ int } +type PrivateKeyRSA struct{ _ int } + +func DecryptRSAOAEP(h, mgfHash hash.Hash, priv *PrivateKeyRSA, ciphertext, label []byte) ([]byte, error) { + panic("boringcrypto: not available") +} +func DecryptRSAPKCS1(priv *PrivateKeyRSA, ciphertext []byte) ([]byte, error) { + panic("boringcrypto: not available") +} +func DecryptRSANoPadding(priv *PrivateKeyRSA, ciphertext []byte) ([]byte, error) { + panic("boringcrypto: not available") +} +func EncryptRSAOAEP(h, mgfHash hash.Hash, pub *PublicKeyRSA, msg, label []byte) ([]byte, error) { + panic("boringcrypto: not available") +} +func EncryptRSAPKCS1(pub *PublicKeyRSA, msg []byte) ([]byte, error) { + panic("boringcrypto: not available") +} +func EncryptRSANoPadding(pub *PublicKeyRSA, msg []byte) ([]byte, error) { + panic("boringcrypto: not available") +} +func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv BigInt, err error) { + panic("boringcrypto: not available") +} +func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv BigInt) (*PrivateKeyRSA, error) { + panic("boringcrypto: not available") +} +func NewPublicKeyRSA(N, E BigInt) (*PublicKeyRSA, error) { panic("boringcrypto: not available") } +func SignRSAPKCS1v15(priv *PrivateKeyRSA, h crypto.Hash, hashed []byte) ([]byte, error) { + panic("boringcrypto: not available") +} +func SignRSAPSS(priv *PrivateKeyRSA, h crypto.Hash, hashed []byte, saltLen int) ([]byte, error) { + panic("boringcrypto: not available") +} +func VerifyRSAPKCS1v15(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte) error { + panic("boringcrypto: not available") +} +func VerifyRSAPSS(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte, saltLen int) error { + panic("boringcrypto: not available") +} + +type PublicKeyECDH struct{} +type PrivateKeyECDH struct{} + +func ECDH(*PrivateKeyECDH, *PublicKeyECDH) ([]byte, error) { panic("boringcrypto: not available") } +func GenerateKeyECDH(string) (*PrivateKeyECDH, []byte, error) { panic("boringcrypto: not available") } +func NewPrivateKeyECDH(string, []byte) (*PrivateKeyECDH, error) { panic("boringcrypto: not available") } +func NewPublicKeyECDH(string, []byte) (*PublicKeyECDH, error) { panic("boringcrypto: not available") } +func (*PublicKeyECDH) Bytes() []byte { panic("boringcrypto: not available") } +func (*PrivateKeyECDH) PublicKey() (*PublicKeyECDH, error) { panic("boringcrypto: not available") } diff --git a/internal/boring/rand.go b/internal/boring/rand.go new file mode 100644 index 000000000..7639c0190 --- /dev/null +++ b/internal/boring/rand.go @@ -0,0 +1,24 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build boringcrypto && linux && (amd64 || arm64) && !android && !cmd_go_bootstrap && !msan + +package boring + +// #include "goboringcrypto.h" +import "C" +import "unsafe" + +type randReader int + +func (randReader) Read(b []byte) (int, error) { + // Note: RAND_bytes should never fail; the return value exists only for historical reasons. + // We check it even so. + if len(b) > 0 && C._goboringcrypto_RAND_bytes((*C.uint8_t)(unsafe.Pointer(&b[0])), C.size_t(len(b))) == 0 { + return 0, fail("RAND_bytes") + } + return len(b), nil +} + +const RandReader = randReader(0) diff --git a/internal/boring/sig/sig.go b/internal/boring/sig/sig.go new file mode 100644 index 000000000..716c03c5e --- /dev/null +++ b/internal/boring/sig/sig.go @@ -0,0 +1,17 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package sig holds “code signatures” that can be called +// and will result in certain code sequences being linked into +// the final binary. The functions themselves are no-ops. +package sig + +// BoringCrypto indicates that the BoringCrypto module is present. +func BoringCrypto() + +// FIPSOnly indicates that package crypto/tls/fipsonly is present. +func FIPSOnly() + +// StandardCrypto indicates that standard Go crypto is present. +func StandardCrypto() diff --git a/internal/boring/sig/sig_amd64.s b/internal/boring/sig/sig_amd64.s new file mode 100644 index 000000000..64e3462e4 --- /dev/null +++ b/internal/boring/sig/sig_amd64.s @@ -0,0 +1,54 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "textflag.h" + +// These functions are no-ops, but you can search for their implementations +// to find out whether they are linked into a particular binary. +// +// Each function consists of a two-byte jump over the next 29-bytes, +// then a 5-byte indicator sequence unlikely to occur in real x86 instructions, +// then a randomly-chosen 24-byte sequence, and finally a return instruction +// (the target of the jump). +// +// These sequences are known to rsc.io/goversion. + +#define START \ + BYTE $0xEB; BYTE $0x1D; BYTE $0xF4; BYTE $0x48; BYTE $0xF4; BYTE $0x4B; BYTE $0xF4 + +#define END \ + BYTE $0xC3 + +// BoringCrypto indicates that BoringCrypto (in particular, its func init) is present. +TEXT ·BoringCrypto(SB),NOSPLIT,$0 + START + BYTE $0xB3; BYTE $0x32; BYTE $0xF5; BYTE $0x28; + BYTE $0x13; BYTE $0xA3; BYTE $0xB4; BYTE $0x50; + BYTE $0xD4; BYTE $0x41; BYTE $0xCC; BYTE $0x24; + BYTE $0x85; BYTE $0xF0; BYTE $0x01; BYTE $0x45; + BYTE $0x4E; BYTE $0x92; BYTE $0x10; BYTE $0x1B; + BYTE $0x1D; BYTE $0x2F; BYTE $0x19; BYTE $0x50; + END + +// StandardCrypto indicates that standard Go crypto is present. +TEXT ·StandardCrypto(SB),NOSPLIT,$0 + START + BYTE $0xba; BYTE $0xee; BYTE $0x4d; BYTE $0xfa; + BYTE $0x98; BYTE $0x51; BYTE $0xca; BYTE $0x56; + BYTE $0xa9; BYTE $0x11; BYTE $0x45; BYTE $0xe8; + BYTE $0x3e; BYTE $0x99; BYTE $0xc5; BYTE $0x9c; + BYTE $0xf9; BYTE $0x11; BYTE $0xcb; BYTE $0x8e; + BYTE $0x80; BYTE $0xda; BYTE $0xf1; BYTE $0x2f; + END + +// FIPSOnly indicates that crypto/tls/fipsonly is present. +TEXT ·FIPSOnly(SB),NOSPLIT,$0 + START + BYTE $0x36; BYTE $0x3C; BYTE $0xB9; BYTE $0xCE; + BYTE $0x9D; BYTE $0x68; BYTE $0x04; BYTE $0x7D; + BYTE $0x31; BYTE $0xF2; BYTE $0x8D; BYTE $0x32; + BYTE $0x5D; BYTE $0x5C; BYTE $0xA5; BYTE $0x87; + BYTE $0x3F; BYTE $0x5D; BYTE $0x80; BYTE $0xCA; + BYTE $0xF6; BYTE $0xD6; BYTE $0x15; BYTE $0x1B; + END diff --git a/internal/boring/sig/sig_other.s b/internal/boring/sig/sig_other.s new file mode 100644 index 000000000..2bbb1df30 --- /dev/null +++ b/internal/boring/sig/sig_other.s @@ -0,0 +1,20 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// These functions are no-ops. +// On amd64 they have recognizable implementations, so that you can +// search a particular binary to see if they are present. +// On other platforms (those using this source file), they don't. + +//go:build !amd64 +// +build !amd64 + +TEXT ·BoringCrypto(SB),$0 + RET + +TEXT ·FIPSOnly(SB),$0 + RET + +TEXT ·StandardCrypto(SB),$0 + RET diff --git a/pkg/tls/auth.go b/pkg/tls/auth.go index ad5f9a2e4..7c5675c6d 100644 --- a/pkg/tls/auth.go +++ b/pkg/tls/auth.go @@ -155,9 +155,9 @@ var rsaSignatureSchemes = []struct { {PSSWithSHA256, crypto.SHA256.Size()*2 + 2, VersionTLS13}, {PSSWithSHA384, crypto.SHA384.Size()*2 + 2, VersionTLS13}, {PSSWithSHA512, crypto.SHA512.Size()*2 + 2, VersionTLS13}, - // PKCS#1 v1.5 uses prefixes from hashPrefixes in crypto/rsa, and requires + // PKCS #1 v1.5 uses prefixes from hashPrefixes in crypto/rsa, and requires // emLen >= len(prefix) + hLen + 11 - // TLS 1.3 dropped support for PKCS#1 v1.5 in favor of RSA-PSS. + // TLS 1.3 dropped support for PKCS #1 v1.5 in favor of RSA-PSS. {PKCS1WithSHA256, 19 + crypto.SHA256.Size() + 11, VersionTLS12}, {PKCS1WithSHA384, 19 + crypto.SHA384.Size() + 11, VersionTLS12}, {PKCS1WithSHA512, 19 + crypto.SHA512.Size() + 11, VersionTLS12}, @@ -169,6 +169,7 @@ var rsaSignatureSchemes = []struct { // and optionally filtered by its explicit SupportedSignatureAlgorithms. // // This function must be kept in sync with supportedSignatureAlgorithms. +// FIPS filtering is applied in the caller, selectSignatureScheme. func signatureSchemesForCertificate(version uint16, cert *Certificate) []SignatureScheme { priv, ok := cert.PrivateKey.(crypto.Signer) if !ok { @@ -241,6 +242,9 @@ func selectSignatureScheme(vers uint16, c *Certificate, peerAlgs []SignatureSche // Pick signature scheme in the peer's preference order, as our // preference order is not configurable. for _, preferredAlg := range peerAlgs { + if needFIPS() && !isSupportedSignatureAlgorithm(preferredAlg, fipsSupportedSignatureAlgorithms) { + continue + } if isSupportedSignatureAlgorithm(preferredAlg, supportedAlgs) { return preferredAlg, nil } diff --git a/pkg/tls/auth_test.go b/pkg/tls/auth_test.go new file mode 100644 index 000000000..c23d93f3c --- /dev/null +++ b/pkg/tls/auth_test.go @@ -0,0 +1,168 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "crypto" + "testing" +) + +func TestSignatureSelection(t *testing.T) { + rsaCert := &Certificate{ + Certificate: [][]byte{testRSACertificate}, + PrivateKey: testRSAPrivateKey, + } + pkcs1Cert := &Certificate{ + Certificate: [][]byte{testRSACertificate}, + PrivateKey: testRSAPrivateKey, + SupportedSignatureAlgorithms: []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256}, + } + ecdsaCert := &Certificate{ + Certificate: [][]byte{testP256Certificate}, + PrivateKey: testP256PrivateKey, + } + ed25519Cert := &Certificate{ + Certificate: [][]byte{testEd25519Certificate}, + PrivateKey: testEd25519PrivateKey, + } + + tests := []struct { + cert *Certificate + peerSigAlgs []SignatureScheme + tlsVersion uint16 + + expectedSigAlg SignatureScheme + expectedSigType uint8 + expectedHash crypto.Hash + }{ + {rsaCert, []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256}, VersionTLS12, PKCS1WithSHA1, signaturePKCS1v15, crypto.SHA1}, + {rsaCert, []SignatureScheme{PKCS1WithSHA512, PKCS1WithSHA1}, VersionTLS12, PKCS1WithSHA512, signaturePKCS1v15, crypto.SHA512}, + {rsaCert, []SignatureScheme{PSSWithSHA256, PKCS1WithSHA256}, VersionTLS12, PSSWithSHA256, signatureRSAPSS, crypto.SHA256}, + {pkcs1Cert, []SignatureScheme{PSSWithSHA256, PKCS1WithSHA256}, VersionTLS12, PKCS1WithSHA256, signaturePKCS1v15, crypto.SHA256}, + {rsaCert, []SignatureScheme{PSSWithSHA384, PKCS1WithSHA1}, VersionTLS13, PSSWithSHA384, signatureRSAPSS, crypto.SHA384}, + {ecdsaCert, []SignatureScheme{ECDSAWithSHA1}, VersionTLS12, ECDSAWithSHA1, signatureECDSA, crypto.SHA1}, + {ecdsaCert, []SignatureScheme{ECDSAWithP256AndSHA256}, VersionTLS12, ECDSAWithP256AndSHA256, signatureECDSA, crypto.SHA256}, + {ecdsaCert, []SignatureScheme{ECDSAWithP256AndSHA256}, VersionTLS13, ECDSAWithP256AndSHA256, signatureECDSA, crypto.SHA256}, + {ed25519Cert, []SignatureScheme{Ed25519}, VersionTLS12, Ed25519, signatureEd25519, directSigning}, + {ed25519Cert, []SignatureScheme{Ed25519}, VersionTLS13, Ed25519, signatureEd25519, directSigning}, + + // TLS 1.2 without signature_algorithms extension + {rsaCert, nil, VersionTLS12, PKCS1WithSHA1, signaturePKCS1v15, crypto.SHA1}, + {ecdsaCert, nil, VersionTLS12, ECDSAWithSHA1, signatureECDSA, crypto.SHA1}, + + // TLS 1.2 does not restrict the ECDSA curve (our ecdsaCert is P-256) + {ecdsaCert, []SignatureScheme{ECDSAWithP384AndSHA384}, VersionTLS12, ECDSAWithP384AndSHA384, signatureECDSA, crypto.SHA384}, + } + + for testNo, test := range tests { + sigAlg, err := selectSignatureScheme(test.tlsVersion, test.cert, test.peerSigAlgs) + if err != nil { + t.Errorf("test[%d]: unexpected selectSignatureScheme error: %v", testNo, err) + } + if test.expectedSigAlg != sigAlg { + t.Errorf("test[%d]: expected signature scheme %v, got %v", testNo, test.expectedSigAlg, sigAlg) + } + sigType, hashFunc, err := typeAndHashFromSignatureScheme(sigAlg) + if err != nil { + t.Errorf("test[%d]: unexpected typeAndHashFromSignatureScheme error: %v", testNo, err) + } + if test.expectedSigType != sigType { + t.Errorf("test[%d]: expected signature algorithm %#x, got %#x", testNo, test.expectedSigType, sigType) + } + if test.expectedHash != hashFunc { + t.Errorf("test[%d]: expected hash function %#x, got %#x", testNo, test.expectedHash, hashFunc) + } + } + + brokenCert := &Certificate{ + Certificate: [][]byte{testRSACertificate}, + PrivateKey: testRSAPrivateKey, + SupportedSignatureAlgorithms: []SignatureScheme{Ed25519}, + } + + badTests := []struct { + cert *Certificate + peerSigAlgs []SignatureScheme + tlsVersion uint16 + }{ + {rsaCert, []SignatureScheme{ECDSAWithP256AndSHA256, ECDSAWithSHA1}, VersionTLS12}, + {ecdsaCert, []SignatureScheme{PKCS1WithSHA256, PKCS1WithSHA1}, VersionTLS12}, + {rsaCert, []SignatureScheme{0}, VersionTLS12}, + {ed25519Cert, []SignatureScheme{ECDSAWithP256AndSHA256, ECDSAWithSHA1}, VersionTLS12}, + {ecdsaCert, []SignatureScheme{Ed25519}, VersionTLS12}, + {brokenCert, []SignatureScheme{Ed25519}, VersionTLS12}, + {brokenCert, []SignatureScheme{PKCS1WithSHA256}, VersionTLS12}, + // RFC 5246, Section 7.4.1.4.1, says to only consider {sha1,ecdsa} as + // default when the extension is missing, and RFC 8422 does not update + // it. Anyway, if a stack supports Ed25519 it better support sigalgs. + {ed25519Cert, nil, VersionTLS12}, + // TLS 1.3 has no default signature_algorithms. + {rsaCert, nil, VersionTLS13}, + {ecdsaCert, nil, VersionTLS13}, + {ed25519Cert, nil, VersionTLS13}, + // Wrong curve, which TLS 1.3 checks + {ecdsaCert, []SignatureScheme{ECDSAWithP384AndSHA384}, VersionTLS13}, + // TLS 1.3 does not support PKCS1v1.5 or SHA-1. + {rsaCert, []SignatureScheme{PKCS1WithSHA256}, VersionTLS13}, + {pkcs1Cert, []SignatureScheme{PSSWithSHA256, PKCS1WithSHA256}, VersionTLS13}, + {ecdsaCert, []SignatureScheme{ECDSAWithSHA1}, VersionTLS13}, + // The key can be too small for the hash. + {rsaCert, []SignatureScheme{PSSWithSHA512}, VersionTLS12}, + } + + for testNo, test := range badTests { + sigAlg, err := selectSignatureScheme(test.tlsVersion, test.cert, test.peerSigAlgs) + if err == nil { + t.Errorf("test[%d]: unexpected success, got %v", testNo, sigAlg) + } + } +} + +func TestLegacyTypeAndHash(t *testing.T) { + sigType, hashFunc, err := legacyTypeAndHashFromPublicKey(testRSAPrivateKey.Public()) + if err != nil { + t.Errorf("RSA: unexpected error: %v", err) + } + if expectedSigType := signaturePKCS1v15; expectedSigType != sigType { + t.Errorf("RSA: expected signature type %#x, got %#x", expectedSigType, sigType) + } + if expectedHashFunc := crypto.MD5SHA1; expectedHashFunc != hashFunc { + t.Errorf("RSA: expected hash %#x, got %#x", expectedHashFunc, hashFunc) + } + + sigType, hashFunc, err = legacyTypeAndHashFromPublicKey(testECDSAPrivateKey.Public()) + if err != nil { + t.Errorf("ECDSA: unexpected error: %v", err) + } + if expectedSigType := signatureECDSA; expectedSigType != sigType { + t.Errorf("ECDSA: expected signature type %#x, got %#x", expectedSigType, sigType) + } + if expectedHashFunc := crypto.SHA1; expectedHashFunc != hashFunc { + t.Errorf("ECDSA: expected hash %#x, got %#x", expectedHashFunc, hashFunc) + } + + // Ed25519 is not supported by TLS 1.0 and 1.1. + _, _, err = legacyTypeAndHashFromPublicKey(testEd25519PrivateKey.Public()) + if err == nil { + t.Errorf("Ed25519: unexpected success") + } +} + +// TestSupportedSignatureAlgorithms checks that all supportedSignatureAlgorithms +// have valid type and hash information. +func TestSupportedSignatureAlgorithms(t *testing.T) { + for _, sigAlg := range supportedSignatureAlgorithms() { + sigType, hash, err := typeAndHashFromSignatureScheme(sigAlg) + if err != nil { + t.Errorf("%v: unexpected error: %v", sigAlg, err) + } + if sigType == 0 { + t.Errorf("%v: missing signature type", sigAlg) + } + if hash == 0 && sigAlg != Ed25519 { + t.Errorf("%v: missing hash", sigAlg) + } + } +} diff --git a/pkg/tls/cache.go b/pkg/tls/cache.go new file mode 100644 index 000000000..fc8f2c084 --- /dev/null +++ b/pkg/tls/cache.go @@ -0,0 +1,95 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "crypto/x509" + "runtime" + "sync" + "sync/atomic" +) + +type cacheEntry struct { + refs atomic.Int64 + cert *x509.Certificate +} + +// certCache implements an intern table for reference counted x509.Certificates, +// implemented in a similar fashion to BoringSSL's CRYPTO_BUFFER_POOL. This +// allows for a single x509.Certificate to be kept in memory and referenced from +// multiple Conns. Returned references should not be mutated by callers. Certificates +// are still safe to use after they are removed from the cache. +// +// Certificates are returned wrapped in a activeCert struct that should be held by +// the caller. When references to the activeCert are freed, the number of references +// to the certificate in the cache is decremented. Once the number of references +// reaches zero, the entry is evicted from the cache. +// +// The main difference between this implementation and CRYPTO_BUFFER_POOL is that +// CRYPTO_BUFFER_POOL is a more generic structure which supports blobs of data, +// rather than specific structures. Since we only care about x509.Certificates, +// certCache is implemented as a specific cache, rather than a generic one. +// +// See https://boringssl.googlesource.com/boringssl/+/master/include/openssl/pool.h +// and https://boringssl.googlesource.com/boringssl/+/master/crypto/pool/pool.c +// for the BoringSSL reference. +type certCache struct { + sync.Map +} + +var clientCertCache = new(certCache) + +// activeCert is a handle to a certificate held in the cache. Once there are +// no alive activeCerts for a given certificate, the certificate is removed +// from the cache by a finalizer. +type activeCert struct { + cert *x509.Certificate +} + +// active increments the number of references to the entry, wraps the +// certificate in the entry in a activeCert, and sets the finalizer. +// +// Note that there is a race between active and the finalizer set on the +// returned activeCert, triggered if active is called after the ref count is +// decremented such that refs may be > 0 when evict is called. We consider this +// safe, since the caller holding an activeCert for an entry that is no longer +// in the cache is fine, with the only side effect being the memory overhead of +// there being more than one distinct reference to a certificate alive at once. +func (cc *certCache) active(e *cacheEntry) *activeCert { + e.refs.Add(1) + a := &activeCert{e.cert} + runtime.SetFinalizer(a, func(_ *activeCert) { + if e.refs.Add(-1) == 0 { + cc.evict(e) + } + }) + return a +} + +// evict removes a cacheEntry from the cache. +func (cc *certCache) evict(e *cacheEntry) { + cc.Delete(string(e.cert.Raw)) +} + +// newCert returns a x509.Certificate parsed from der. If there is already a copy +// of the certificate in the cache, a reference to the existing certificate will +// be returned. Otherwise, a fresh certificate will be added to the cache, and +// the reference returned. The returned reference should not be mutated. +func (cc *certCache) newCert(der []byte) (*activeCert, error) { + if entry, ok := cc.Load(string(der)); ok { + return cc.active(entry.(*cacheEntry)), nil + } + + cert, err := x509.ParseCertificate(der) + if err != nil { + return nil, err + } + + entry := &cacheEntry{cert: cert} + if entry, loaded := cc.LoadOrStore(string(der), entry); loaded { + return cc.active(entry.(*cacheEntry)), nil + } + return cc.active(entry), nil +} diff --git a/pkg/tls/cache_test.go b/pkg/tls/cache_test.go new file mode 100644 index 000000000..284673419 --- /dev/null +++ b/pkg/tls/cache_test.go @@ -0,0 +1,117 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "encoding/pem" + "fmt" + "runtime" + "testing" + "time" +) + +func TestCertCache(t *testing.T) { + cc := certCache{} + p, _ := pem.Decode([]byte(rsaCertPEM)) + if p == nil { + t.Fatal("Failed to decode certificate") + } + + certA, err := cc.newCert(p.Bytes) + if err != nil { + t.Fatalf("newCert failed: %s", err) + } + certB, err := cc.newCert(p.Bytes) + if err != nil { + t.Fatalf("newCert failed: %s", err) + } + if certA.cert != certB.cert { + t.Fatal("newCert returned a unique reference for a duplicate certificate") + } + + if entry, ok := cc.Load(string(p.Bytes)); !ok { + t.Fatal("cache does not contain expected entry") + } else { + if refs := entry.(*cacheEntry).refs.Load(); refs != 2 { + t.Fatalf("unexpected number of references: got %d, want 2", refs) + } + } + + timeoutRefCheck := func(t *testing.T, key string, count int64) { + t.Helper() + c := time.After(4 * time.Second) + for { + select { + case <-c: + t.Fatal("timed out waiting for expected ref count") + default: + e, ok := cc.Load(key) + if !ok && count != 0 { + t.Fatal("cache does not contain expected key") + } else if count == 0 && !ok { + return + } + + if e.(*cacheEntry).refs.Load() == count { + return + } + } + } + } + + // Keep certA alive until at least now, so that we can + // purposefully nil it and force the finalizer to be + // called. + runtime.KeepAlive(certA) + certA = nil + runtime.GC() + + timeoutRefCheck(t, string(p.Bytes), 1) + + // Keep certB alive until at least now, so that we can + // purposefully nil it and force the finalizer to be + // called. + runtime.KeepAlive(certB) + certB = nil + runtime.GC() + + timeoutRefCheck(t, string(p.Bytes), 0) +} + +func BenchmarkCertCache(b *testing.B) { + p, _ := pem.Decode([]byte(rsaCertPEM)) + if p == nil { + b.Fatal("Failed to decode certificate") + } + + cc := certCache{} + b.ReportAllocs() + b.ResetTimer() + // We expect that calling newCert additional times after + // the initial call should not cause additional allocations. + for extra := 0; extra < 4; extra++ { + b.Run(fmt.Sprint(extra), func(b *testing.B) { + actives := make([]*activeCert, extra+1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + var err error + actives[0], err = cc.newCert(p.Bytes) + if err != nil { + b.Fatal(err) + } + for j := 0; j < extra; j++ { + actives[j+1], err = cc.newCert(p.Bytes) + if err != nil { + b.Fatal(err) + } + } + for j := 0; j < extra+1; j++ { + actives[j] = nil + } + runtime.GC() + } + }) + } +} diff --git a/pkg/tls/cipher_suites.go b/pkg/tls/cipher_suites.go index 62ec2bdb0..3077e0ab7 100644 --- a/pkg/tls/cipher_suites.go +++ b/pkg/tls/cipher_suites.go @@ -13,11 +13,13 @@ import ( "crypto/rc4" "crypto/sha1" "crypto/sha256" - "crypto/x509" "fmt" "hash" + "runtime" "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/sys/cpu" + "github.com/panjf2000/gnet/v2/internal/boring" ) // CipherSuite is a TLS cipher suite. Note that most functions in this package @@ -36,7 +38,7 @@ type CipherSuite struct { } var ( - supportedUpToTLS12 = []uint16{VersionTLS11, VersionTLS12} + supportedUpToTLS12 = []uint16{VersionTLS10, VersionTLS11, VersionTLS12} supportedOnlyTLS12 = []uint16{VersionTLS12} supportedOnlyTLS13 = []uint16{VersionTLS13} ) @@ -46,10 +48,10 @@ var ( // InsecureCipherSuites. // // The list is sorted by ID. Note that the default cipher suites selected by -// this package might depend on logic that can't be captured by a static list. +// this package might depend on logic that can't be captured by a static list, +// and might not match those returned by this function. func CipherSuites() []*CipherSuite { return []*CipherSuite{ - {TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, false}, {TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, {TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, {TLS_RSA_WITH_AES_128_GCM_SHA256, "TLS_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, @@ -61,7 +63,6 @@ func CipherSuites() []*CipherSuite { {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, - {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, false}, {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, @@ -79,13 +80,15 @@ func CipherSuites() []*CipherSuite { // Most applications should not use the cipher suites in this list, and should // only use those returned by CipherSuites. func InsecureCipherSuites() []*CipherSuite { - // RC4 suites are broken because RC4 is. - // CBC-SHA256 suites have no Lucky13 countermeasures. + // This list includes RC4, CBC_SHA256, and 3DES cipher suites. See + // cipherSuitesPreferenceOrder for details. return []*CipherSuite{ {TLS_RSA_WITH_RC4_128_SHA, "TLS_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, + {TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true}, {TLS_RSA_WITH_AES_128_CBC_SHA256, "TLS_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, {TLS_ECDHE_RSA_WITH_RC4_128_SHA, "TLS_ECDHE_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, + {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true}, {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, } @@ -108,25 +111,6 @@ func CipherSuiteName(id uint16) string { return fmt.Sprintf("0x%04X", id) } -// a keyAgreement implements the client and server side of a TLS key agreement -// protocol by generating and processing key exchange messages. -type keyAgreement interface { - // On the server side, the first two methods are called in order. - - // In the case that the key agreement protocol doesn't use a - // ServerKeyExchange message, generateServerKeyExchange can return nil, - // nil. - generateServerKeyExchange(*Config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error) - processClientKeyExchange(*Config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error) - - // On the client side, the next two methods are called in order. - - // This method may not be called if the server doesn't send a - // ServerKeyExchange message. - processServerKeyExchange(*Config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error - generateClientKeyExchange(*Config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) -} - const ( // suiteECDHE indicates that the cipher suite involves elliptic curve // Diffie-Hellman. This means that it should only be selected when the @@ -144,12 +128,10 @@ const ( // suiteSHA384 indicates that the cipher suite uses SHA384 as the // handshake hash. suiteSHA384 - // suiteDefaultOff indicates that this cipher suite is not included by - // default. - suiteDefaultOff ) -// A cipherSuite is a specific combination of key agreement, cipher and MAC function. +// A cipherSuite is a TLS 1.0–1.2 cipher suite, and defines the key exchange +// mechanism, as well as the cipher+MAC pair or the AEAD. type cipherSuite struct { id uint16 // the lengths, in bytes, of the key material needed for each component. @@ -160,41 +142,37 @@ type cipherSuite struct { // flags is a bitmask of the suite* values, above. flags int cipher func(key, iv []byte, isRead bool) interface{} - mac func(version uint16, macKey []byte) macFunction + mac func(key []byte) hash.Hash aead func(key, fixedNonce []byte) aead } -var cipherSuites = []*cipherSuite{ - // Ciphersuite order is chosen so that ECDHE comes before plain RSA and - // AEADs are the top preference. +var cipherSuites = []*cipherSuite{ // TODO: replace with a map, since the order doesn't matter. {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM}, {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadAESGCM}, {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, - {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteDefaultOff, cipherAES, macSHA256, nil}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheRSAKA, suiteECDHE | suiteTLS12, cipherAES, macSHA256, nil}, {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, - {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteDefaultOff, cipherAES, macSHA256, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, cipherAES, macSHA256, nil}, {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, {TLS_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, rsaKA, suiteTLS12, nil, nil, aeadAESGCM}, {TLS_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, rsaKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, - {TLS_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, rsaKA, suiteTLS12 | suiteDefaultOff, cipherAES, macSHA256, nil}, + {TLS_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, rsaKA, suiteTLS12, cipherAES, macSHA256, nil}, {TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, {TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil}, {TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil}, - - // RC4-based cipher suites are disabled by default. - {TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, suiteDefaultOff, cipherRC4, macSHA1, nil}, - {TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE | suiteDefaultOff, cipherRC4, macSHA1, nil}, - {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteDefaultOff, cipherRC4, macSHA1, nil}, + {TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, 0, cipherRC4, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE, cipherRC4, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherRC4, macSHA1, nil}, } -// selectCipherSuite returns the first cipher suite from ids which is also in -// supportedIDs and passes the ok filter. +// selectCipherSuite returns the first TLS 1.0–1.2 cipher suite from ids which +// is also in supportedIDs and passes the ok filter. func selectCipherSuite(ids, supportedIDs []uint16, ok func(*cipherSuite) bool) *cipherSuite { for _, id := range ids { candidate := cipherSuiteByID(id) @@ -220,13 +198,208 @@ type cipherSuiteTLS13 struct { hash crypto.Hash } -var cipherSuitesTLS13 = []*cipherSuiteTLS13{ +var cipherSuitesTLS13 = []*cipherSuiteTLS13{ // TODO: replace with a map. {TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256}, {TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256}, {TLS_AES_256_GCM_SHA384, 32, aeadAESGCMTLS13, crypto.SHA384}, } -func cipherRC4(key, iv []byte, isRead bool) interface{} { +// cipherSuitesPreferenceOrder is the order in which we'll select (on the +// server) or advertise (on the client) TLS 1.0–1.2 cipher suites. +// +// Cipher suites are filtered but not reordered based on the application and +// peer's preferences, meaning we'll never select a suite lower in this list if +// any higher one is available. This makes it more defensible to keep weaker +// cipher suites enabled, especially on the server side where we get the last +// word, since there are no known downgrade attacks on cipher suites selection. +// +// The list is sorted by applying the following priority rules, stopping at the +// first (most important) applicable one: +// +// - Anything else comes before RC4 +// +// RC4 has practically exploitable biases. See https://www.rc4nomore.com. +// +// - Anything else comes before CBC_SHA256 +// +// SHA-256 variants of the CBC ciphersuites don't implement any Lucky13 +// countermeasures. See http://www.isg.rhul.ac.uk/tls/Lucky13.html and +// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. +// +// - Anything else comes before 3DES +// +// 3DES has 64-bit blocks, which makes it fundamentally susceptible to +// birthday attacks. See https://sweet32.info. +// +// - ECDHE comes before anything else +// +// Once we got the broken stuff out of the way, the most important +// property a cipher suite can have is forward secrecy. We don't +// implement FFDHE, so that means ECDHE. +// +// - AEADs come before CBC ciphers +// +// Even with Lucky13 countermeasures, MAC-then-Encrypt CBC cipher suites +// are fundamentally fragile, and suffered from an endless sequence of +// padding oracle attacks. See https://eprint.iacr.org/2015/1129, +// https://www.imperialviolet.org/2014/12/08/poodleagain.html, and +// https://blog.cloudflare.com/yet-another-padding-oracle-in-openssl-cbc-ciphersuites/. +// +// - AES comes before ChaCha20 +// +// When AES hardware is available, AES-128-GCM and AES-256-GCM are faster +// than ChaCha20Poly1305. +// +// When AES hardware is not available, AES-128-GCM is one or more of: much +// slower, way more complex, and less safe (because not constant time) +// than ChaCha20Poly1305. +// +// We use this list if we think both peers have AES hardware, and +// cipherSuitesPreferenceOrderNoAES otherwise. +// +// - AES-128 comes before AES-256 +// +// The only potential advantages of AES-256 are better multi-target +// margins, and hypothetical post-quantum properties. Neither apply to +// TLS, and AES-256 is slower due to its four extra rounds (which don't +// contribute to the advantages above). +// +// - ECDSA comes before RSA +// +// The relative order of ECDSA and RSA cipher suites doesn't matter, +// as they depend on the certificate. Pick one to get a stable order. +var cipherSuitesPreferenceOrder = []uint16{ + // AEADs w/ ECDHE + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + + // CBC w/ ECDHE + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + + // AEADs w/o ECDHE + TLS_RSA_WITH_AES_128_GCM_SHA256, + TLS_RSA_WITH_AES_256_GCM_SHA384, + + // CBC w/o ECDHE + TLS_RSA_WITH_AES_128_CBC_SHA, + TLS_RSA_WITH_AES_256_CBC_SHA, + + // 3DES + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + TLS_RSA_WITH_3DES_EDE_CBC_SHA, + + // CBC_SHA256 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + TLS_RSA_WITH_AES_128_CBC_SHA256, + + // RC4 + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, + TLS_RSA_WITH_RC4_128_SHA, +} + +var cipherSuitesPreferenceOrderNoAES = []uint16{ + // ChaCha20Poly1305 + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + + // AES-GCM w/ ECDHE + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + + // The rest of cipherSuitesPreferenceOrder. + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + TLS_RSA_WITH_AES_128_GCM_SHA256, + TLS_RSA_WITH_AES_256_GCM_SHA384, + TLS_RSA_WITH_AES_128_CBC_SHA, + TLS_RSA_WITH_AES_256_CBC_SHA, + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + TLS_RSA_WITH_3DES_EDE_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + TLS_RSA_WITH_AES_128_CBC_SHA256, + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, + TLS_RSA_WITH_RC4_128_SHA, +} + +// disabledCipherSuites are not used unless explicitly listed in +// Config.CipherSuites. They MUST be at the end of cipherSuitesPreferenceOrder. +var disabledCipherSuites = []uint16{ + // CBC_SHA256 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + TLS_RSA_WITH_AES_128_CBC_SHA256, + + // RC4 + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, + TLS_RSA_WITH_RC4_128_SHA, +} + +var ( + defaultCipherSuitesLen = len(cipherSuitesPreferenceOrder) - len(disabledCipherSuites) + defaultCipherSuites = cipherSuitesPreferenceOrder[:defaultCipherSuitesLen] +) + +// defaultCipherSuitesTLS13 is also the preference order, since there are no +// disabled by default TLS 1.3 cipher suites. The same AES vs ChaCha20 logic as +// cipherSuitesPreferenceOrder applies. +var defaultCipherSuitesTLS13 = []uint16{ + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, + TLS_CHACHA20_POLY1305_SHA256, +} + +var defaultCipherSuitesTLS13NoAES = []uint16{ + TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, +} + +var ( + hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ + hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL + // Keep in sync with crypto/aes/cipher_s390x.go. + hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR && + (cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM) + + hasAESGCMHardwareSupport = runtime.GOARCH == "amd64" && hasGCMAsmAMD64 || + runtime.GOARCH == "arm64" && hasGCMAsmARM64 || + runtime.GOARCH == "s390x" && hasGCMAsmS390X +) + +var aesgcmCiphers = map[uint16]bool{ + // TLS 1.2 + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: true, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: true, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: true, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: true, + // TLS 1.3 + TLS_AES_128_GCM_SHA256: true, + TLS_AES_256_GCM_SHA384: true, +} + +var nonAESGCMAEADCiphers = map[uint16]bool{ + // TLS 1.2 + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: true, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: true, + // TLS 1.3 + TLS_CHACHA20_POLY1305_SHA256: true, +} + +// aesgcmPreferred returns whether the first known cipher in the preference list +// is an AES-GCM cipher, implying the peer has hardware support for it. +func aesgcmPreferred(ciphers []uint16) bool { + for _, cID := range ciphers { + if c := cipherSuiteByID(cID); c != nil { + return aesgcmCiphers[cID] + } + if c := cipherSuiteTLS13ByID(cID); c != nil { + return aesgcmCiphers[cID] + } + } + return false +} + +func cipherRC4(key, iv []byte, isRead bool) any { cipher, _ := rc4.NewCipher(key) return cipher } @@ -247,24 +420,21 @@ func cipherAES(key, iv []byte, isRead bool) interface{} { return cipher.NewCBCEncrypter(block, iv) } -// macSHA1 returns a macFunction for the given protocol version. -func macSHA1(version uint16, key []byte) macFunction { - return tls10MAC{h: hmac.New(newConstantTimeHash(sha1.New), key)} -} - -// macSHA256 returns a SHA-256 based MAC. These are only supported in TLS 1.2 -// so the given version is ignored. -func macSHA256(version uint16, key []byte) macFunction { - return tls10MAC{h: hmac.New(sha256.New, key)} +// macSHA1 returns a SHA-1 based constant time MAC. +func macSHA1(key []byte) hash.Hash { + h := sha1.New + // The BoringCrypto SHA1 does not have a constant-time + // checksum function, so don't try to use it. + if !boring.Enabled { + h = newConstantTimeHash(h) + } + return hmac.New(h, key) } -type macFunction interface { - // Size returns the length of the MAC. - Size() int - // MAC appends the MAC of (seq, header, data) to out. The extra data is fed - // into the MAC after obtaining the result to normalize timing. The result - // is only valid until the next invocation of MAC as the buffer is reused. - MAC(seq, header, data, extra []byte) []byte +// macSHA256 returns a SHA-256 based MAC. This is only supported in TLS 1.2 and +// is currently only used in disabled-by-default cipher suites. +func macSHA256(key []byte) hash.Hash { + return hmac.New(sha256.New, key) } type aead interface { @@ -303,7 +473,7 @@ func (f *prefixNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([ return f.aead.Open(out, f.nonce[:], ciphertext, additionalData) } -// xoredNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce +// xorNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce // before each call. type xorNonceAEAD struct { nonceMask [aeadNonceLength]byte @@ -346,7 +516,13 @@ func aeadAESGCM(key, noncePrefix []byte) aead { if err != nil { panic(err) } - aead, err := cipher.NewGCM(aes) + var aead cipher.AEAD + if boring.Enabled { + aead, err = boring.NewGCMTLS(aes) + } else { + boring.Unreachable() + aead, err = cipher.NewGCM(aes) + } if err != nil { panic(err) } @@ -406,32 +582,21 @@ func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) } func (c *cthWrapper) Sum(b []byte) []byte { return c.h.ConstantTimeSum(b) } func newConstantTimeHash(h func() hash.Hash) func() hash.Hash { + boring.Unreachable() return func() hash.Hash { return &cthWrapper{h().(constantTimeHash)} } } // tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3. -type tls10MAC struct { - h hash.Hash - buf []byte -} - -func (s tls10MAC) Size() int { - return s.h.Size() -} - -// MAC is guaranteed to take constant time, as long as -// len(seq)+len(header)+len(data)+len(extra) is constant. extra is not fed into -// the MAC, but is only provided to make the timing profile constant. -func (s tls10MAC) MAC(seq, header, data, extra []byte) []byte { - s.h.Reset() - s.h.Write(seq) - s.h.Write(header) - s.h.Write(data) - res := s.h.Sum(s.buf[:0]) +func tls10MAC(h hash.Hash, out, seq, header, data, extra []byte) []byte { + h.Reset() + h.Write(seq) + h.Write(header) + h.Write(data) + res := h.Sum(out) if extra != nil { - s.h.Write(extra) + h.Write(extra) } return res } diff --git a/pkg/tls/common.go b/pkg/tls/common.go index cf9e2235d..007f0f47b 100644 --- a/pkg/tls/common.go +++ b/pkg/tls/common.go @@ -7,6 +7,7 @@ package tls import ( "bytes" "container/list" + "context" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -18,11 +19,10 @@ import ( "errors" "fmt" "io" + "net" "strings" "sync" "time" - - "golang.org/x/sys/cpu" ) const ( @@ -171,11 +171,11 @@ const ( // hash function associated with the Ed25519 signature scheme. var directSigning crypto.Hash = 0 -// supportedSignatureAlgorithms contains the signature and hash algorithms that +// defaultSupportedSignatureAlgorithms contains the signature and hash algorithms that // the code advertises as supported in a TLS 1.2+ ClientHello and in a TLS 1.2+ // CertificateRequest. The two fields are merged to match with TLS 1.3. // Note that in TLS 1.2, the ECDSA algorithms are not constrained to P-256, etc. -var supportedSignatureAlgorithms = []SignatureScheme{ +var defaultSupportedSignatureAlgorithms = []SignatureScheme{ PSSWithSHA256, ECDSAWithP256AndSHA256, Ed25519, @@ -213,28 +213,73 @@ var testingOnlyForceDowngradeCanary bool // ConnectionState records basic TLS details about the connection. type ConnectionState struct { - Version uint16 // TLS version used by the connection (e.g. VersionTLS12) - HandshakeComplete bool // TLS handshake is complete - DidResume bool // connection resumes a previous TLS connection - CipherSuite uint16 // cipher suite in use (TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, ...) - NegotiatedProtocol string // negotiated next protocol (not guaranteed to be from Config.NextProtos) - NegotiatedProtocolIsMutual bool // negotiated protocol was advertised by server (client side only) - ServerName string // server name requested by client, if any - PeerCertificates []*x509.Certificate // certificate chain presented by remote peer - VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates - SignedCertificateTimestamps [][]byte // SCTs from the peer, if any - OCSPResponse []byte // stapled OCSP response from peer, if any + // Version is the TLS version used by the connection (e.g. VersionTLS12). + Version uint16 - // ekm is a closure exposed via ExportKeyingMaterial. - ekm func(label string, context []byte, length int) ([]byte, error) + // HandshakeComplete is true if the handshake has concluded. + HandshakeComplete bool + + // DidResume is true if this connection was successfully resumed from a + // previous session with a session ticket or similar mechanism. + DidResume bool + + // CipherSuite is the cipher suite negotiated for the connection (e.g. + // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_AES_128_GCM_SHA256). + CipherSuite uint16 + + // NegotiatedProtocol is the application protocol negotiated with ALPN. + NegotiatedProtocol string + + // NegotiatedProtocolIsMutual used to indicate a mutual NPN negotiation. + // + // Deprecated: this value is always true. + NegotiatedProtocolIsMutual bool + + // ServerName is the value of the Server Name Indication extension sent by + // the client. It's available both on the server and on the client side. + ServerName string + + // PeerCertificates are the parsed certificates sent by the peer, in the + // order in which they were sent. The first element is the leaf certificate + // that the connection is verified against. + // + // On the client side, it can't be empty. On the server side, it can be + // empty if Config.ClientAuth is not RequireAnyClientCert or + // RequireAndVerifyClientCert. + // + // PeerCertificates and its contents should not be modified. + PeerCertificates []*x509.Certificate + + // VerifiedChains is a list of one or more chains where the first element is + // PeerCertificates[0] and the last element is from Config.RootCAs (on the + // client side) or Config.ClientCAs (on the server side). + // + // On the client side, it's set if Config.InsecureSkipVerify is false. On + // the server side, it's set if Config.ClientAuth is VerifyClientCertIfGiven + // (and the peer provided a certificate) or RequireAndVerifyClientCert. + // + // VerifiedChains and its contents should not be modified. + VerifiedChains [][]*x509.Certificate + + // SignedCertificateTimestamps is a list of SCTs provided by the peer + // through the TLS handshake for the leaf certificate, if any. + SignedCertificateTimestamps [][]byte + + // OCSPResponse is a stapled Online Certificate Status Protocol (OCSP) + // response provided by the peer for the leaf certificate, if any. + OCSPResponse []byte - // TLSUnique contains the "tls-unique" channel binding value (see RFC - // 5929, section 3). For resumed sessions this value will be nil - // because resumption does not include enough context (see - // https://mitls.org/pages/attacks/3SHAKE#channelbindings). This will - // change in future versions of Go once the TLS master-secret fix has - // been standardized and implemented. It is not defined in TLS 1.3. + // TLSUnique contains the "tls-unique" channel binding value (see RFC 5929, + // Section 3). This value will be nil for TLS 1.3 connections and for all + // resumed connections. + // + // Deprecated: there are conditions in which this value might not be unique + // to a connection. See the Security Considerations sections of RFC 5705 and + // RFC 7627, and https://mitls.org/pages/attacks/3SHAKE#channelbindings. TLSUnique []byte + + // ekm is a closure exposed via ExportKeyingMaterial. + ekm func(label string, context []byte, length int) ([]byte, error) } // ExportKeyingMaterial returns length bytes of exported key material in a new @@ -250,10 +295,26 @@ func (cs *ConnectionState) ExportKeyingMaterial(label string, context []byte, le type ClientAuthType int const ( + // NoClientCert indicates that no client certificate should be requested + // during the handshake, and if any certificates are sent they will not + // be verified. NoClientCert ClientAuthType = iota + // RequestClientCert indicates that a client certificate should be requested + // during the handshake, but does not require that the client send any + // certificates. RequestClientCert + // RequireAnyClientCert indicates that a client certificate should be requested + // during the handshake, and that at least one certificate is required to be + // sent by the client, but that certificate is not required to be valid. RequireAnyClientCert + // VerifyClientCertIfGiven indicates that a client certificate should be requested + // during the handshake, but does not require that the client sends a + // certificate. If the client does send a certificate it is required to be + // valid. VerifyClientCertIfGiven + // RequireAndVerifyClientCert indicates that a client certificate should be requested + // during the handshake, and that at least one valid certificate is required + // to be sent by the client. RequireAndVerifyClientCert ) @@ -379,11 +440,21 @@ type ClientHelloInfo struct { // Conn is the underlying net.Conn for the connection. Do not read // from, or write to, this connection; that will cause the TLS // connection to fail. - Conn conn + Conn net.Conn // config is embedded by the GetCertificate or GetConfigForClient caller, // for use with SupportsCertificate. config *Config + + // ctx is the context of the handshake that is in progress. + ctx context.Context +} + +// Context returns the context of the handshake that is in progress. +// This context is a child of the context passed to HandshakeContext, +// if any, and is canceled when the handshake concludes. +func (c *ClientHelloInfo) Context() context.Context { + return c.ctx } // CertificateRequestInfo contains information from a server's @@ -402,6 +473,16 @@ type CertificateRequestInfo struct { // Version is the TLS version that was negotiated for this connection. Version uint16 + + // ctx is the context of the handshake that is in progress. + ctx context.Context +} + +// Context returns the context of the handshake that is in progress. +// This context is a child of the context passed to HandshakeContext, +// if any, and is canceled when the handshake concludes. +func (c *CertificateRequestInfo) Context() context.Context { + return c.ctx } // RenegotiationSupport enumerates the different levels of support for TLS @@ -477,6 +558,8 @@ type Config struct { // If GetCertificate is nil or returns nil, then the certificate is // retrieved from NameToCertificate. If NameToCertificate is nil, the // best element of Certificates will be used. + // + // Once a Certificate is returned it should not be modified. GetCertificate func(*ClientHelloInfo) (*Certificate, error) // GetClientCertificate, if not nil, is called when a server requests a @@ -492,6 +575,8 @@ type Config struct { // // GetClientCertificate may be called multiple times for the same // connection if renegotiation occurs or if TLS 1.3 is in use. + // + // Once a Certificate is returned it should not be modified. GetClientCertificate func(*CertificateRequestInfo) (*Certificate, error) // GetConfigForClient, if not nil, is called after a ClientHello is @@ -520,7 +605,10 @@ type Config struct { // setting InsecureSkipVerify, or (for a server) when ClientAuth is // RequestClientCert or RequireAnyClientCert, then this callback will // be considered but the verifiedChains argument will always be nil. + // + // verifiedChains and its contents should not be modified. VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + // VerifyConnection, if not nil, is called after normal certificate // verification and after VerifyPeerCertificate by either a TLS client // or server. If it returns a non-nil error, the handshake is aborted @@ -537,7 +625,11 @@ type Config struct { RootCAs *x509.CertPool // NextProtos is a list of supported application level protocols, in - // order of preference. + // order of preference. If both peers support ALPN, the selected + // protocol will be one from this list, and the connection will fail + // if there is no mutually supported protocol. If NextProtos is empty + // or the peer doesn't support ALPN, the connection will succeed and + // ConnectionState.NegotiatedProtocol will be empty. NextProtos []string // ServerName is used to verify the hostname on the returned @@ -555,25 +647,29 @@ type Config struct { // by the policy in ClientAuth. ClientCAs *x509.CertPool - // InsecureSkipVerify controls whether a client verifies the - // server's certificate chain and host name. - // If InsecureSkipVerify is true, TLS accepts any certificate - // presented by the server and any host name in that certificate. - // In this mode, TLS is susceptible to man-in-the-middle attacks. - // This should be used only for testing. + // InsecureSkipVerify controls whether a client verifies the server's + // certificate chain and host name. If InsecureSkipVerify is true, crypto/tls + // accepts any certificate presented by the server and any host name in that + // certificate. In this mode, TLS is susceptible to machine-in-the-middle + // attacks unless custom verification is used. This should be used only for + // testing or in combination with VerifyConnection or VerifyPeerCertificate. InsecureSkipVerify bool - // CipherSuites is a list of supported cipher suites for TLS versions up to - // TLS 1.2. If CipherSuites is nil, a default list of secure cipher suites - // is used, with a preference order based on hardware performance. The - // default cipher suites might change over Go versions. Note that TLS 1.3 - // ciphersuites are not configurable. + // CipherSuites is a list of enabled TLS 1.0–1.2 cipher suites. The order of + // the list is ignored. Note that TLS 1.3 ciphersuites are not configurable. + // + // If CipherSuites is nil, a safe default list is used. The default cipher + // suites might change over time. CipherSuites []uint16 - // PreferServerCipherSuites controls whether the server selects the - // client's most preferred ciphersuite, or the server's most preferred - // ciphersuite. If true then the server's preference, as expressed in - // the order of elements in CipherSuites, is used. + // PreferServerCipherSuites is a legacy field and has no effect. + // + // It used to control whether the server would follow the client's or the + // server's preference. Servers now select the best mutually supported + // cipher suite based on logic that takes into account inferred client + // hardware, server hardware, and security. + // + // Deprecated: PreferServerCipherSuites is ignored. PreferServerCipherSuites bool // SessionTicketsDisabled may be set to true to disable session ticket and @@ -596,11 +692,20 @@ type Config struct { ClientSessionCache ClientSessionCache // MinVersion contains the minimum TLS version that is acceptable. - // If zero, TLS 1.0 is currently taken as the minimum. + // + // By default, TLS 1.2 is currently used as the minimum when acting as a + // client, and TLS 1.0 when acting as a server. TLS 1.0 is the minimum + // supported by this package, both as a client and as a server. + // + // The client-side default can temporarily be reverted to TLS 1.0 by + // including the value "x509sha1=1" in the GODEBUG environment variable. + // Note that this option will be removed in Go 1.19 (but it will still be + // possible to set this field to VersionTLS10 explicitly). MinVersion uint16 // MaxVersion contains the maximum TLS version that is acceptable. - // If zero, the maximum version supported by this package is used, + // + // By default, the maximum version supported by this package is used, // which is currently TLS 1.3. MaxVersion uint16 @@ -630,7 +735,7 @@ type Config struct { // mutex protects sessionTicketKeys and autoSessionTicketKeys. mutex sync.RWMutex - // sessionTicketKeys contains zero or more ticket keys. If set, it means the + // sessionTicketKeys contains zero or more ticket keys. If set, it means // the keys were set with SessionTicketKey or SetSessionTicketKeys. The // first key is used for new tickets and any subsequent keys can be used to // decrypt old tickets. The slice contents are not protected by the mutex @@ -682,12 +787,14 @@ func (c *Config) ticketKeyFromBytes(b [32]byte) (key ticketKey) { // ticket, and the lifetime we set for tickets we send. const maxSessionTicketLifetime = 7 * 24 * time.Hour -// Clone returns a shallow clone of c. It is safe to clone a Config that is +// Clone returns a shallow clone of c or nil if c is nil. It is safe to clone a Config that is // being used concurrently by a TLS client or server. func (c *Config) Clone() *Config { + if c == nil { + return nil + } c.mutex.RLock() defer c.mutex.RUnlock() - return &Config{ Rand: c.Rand, Time: c.Time, @@ -863,23 +970,37 @@ func (c *Config) time() time.Time { } func (c *Config) cipherSuites() []uint16 { - s := c.CipherSuites - if s == nil { - s = defaultCipherSuites() + if needFIPS() { + return fipsCipherSuites(c) + } + if c.CipherSuites != nil { + return c.CipherSuites } - return s + return defaultCipherSuites } var supportedVersions = []uint16{ VersionTLS13, VersionTLS12, VersionTLS11, - //VersionTLS10, + VersionTLS10, } -func (c *Config) supportedVersions() []uint16 { +// roleClient and roleServer are meant to call supportedVersions and parents +// with more readability at the callsite. +const roleClient = true +const roleServer = false + +func (c *Config) supportedVersions(isClient bool) []uint16 { versions := make([]uint16, 0, len(supportedVersions)) for _, v := range supportedVersions { + if needFIPS() && (v < fipsMinVersion(c) || v > fipsMaxVersion(c)) { + continue + } + if (c == nil || c.MinVersion == 0) && + isClient && v < VersionTLS12 { + continue + } if c != nil && c.MinVersion != 0 && v < c.MinVersion { continue } @@ -891,8 +1012,8 @@ func (c *Config) supportedVersions() []uint16 { return versions } -func (c *Config) maxSupportedVersion() uint16 { - supportedVersions := c.supportedVersions() +func (c *Config) maxSupportedVersion(isClient bool) uint16 { + supportedVersions := c.supportedVersions(isClient) if len(supportedVersions) == 0 { return 0 } @@ -916,6 +1037,9 @@ func supportedVersionsFromMax(maxVersion uint16) []uint16 { var defaultCurvePreferences = []CurveID{X25519, CurveP256, CurveP384, CurveP521} func (c *Config) curvePreferences() []CurveID { + if needFIPS() { + return fipsCurvePreferences(c) + } if c == nil || len(c.CurvePreferences) == 0 { return defaultCurvePreferences } @@ -933,8 +1057,8 @@ func (c *Config) supportsCurve(curve CurveID) bool { // mutualVersion returns the protocol version to use given the advertised // versions of the peer. Priority is given to the peer preference order. -func (c *Config) mutualVersion(peerVersions []uint16) (uint16, bool) { - supportedVersions := c.supportedVersions() +func (c *Config) mutualVersion(isClient bool, peerVersions []uint16) (uint16, bool) { + supportedVersions := c.supportedVersions(isClient) for _, peerVersion := range peerVersions { for _, v := range supportedVersions { if v == peerVersion { @@ -1013,7 +1137,7 @@ func (chi *ClientHelloInfo) SupportsCertificate(c *Certificate) error { if config == nil { config = &Config{} } - vers, ok := config.mutualVersion(chi.SupportedVersions) + vers, ok := config.mutualVersion(roleServer, chi.SupportedVersions) if !ok { return errors.New("no mutually supported protocol versions") } @@ -1200,7 +1324,9 @@ func (c *Config) BuildNameToCertificate() { if err != nil { continue } - if len(x509Cert.Subject.CommonName) > 0 { + // If SANs are *not* present, some clients will consider the certificate + // valid for the name in the Common Name. + if x509Cert.Subject.CommonName != "" && len(x509Cert.DNSNames) == 0 { c.NameToCertificate[x509Cert.Subject.CommonName] = cert } for _, san := range x509Cert.DNSNames { @@ -1222,7 +1348,7 @@ func (c *Config) writeKeyLog(label string, clientRandom, secret []byte) error { return nil } - logLine := []byte(fmt.Sprintf("%s %x %x\n", label, clientRandom, secret)) + logLine := fmt.Appendf(nil, "%s %x %x\n", label, clientRandom, secret) writerMutex.Lock() _, err := c.KeyLogWriter.Write(logLine) @@ -1355,88 +1481,7 @@ func defaultConfig() *Config { return &emptyConfig } -var ( - once sync.Once - varDefaultCipherSuites []uint16 - varDefaultCipherSuitesTLS13 []uint16 -) - -func defaultCipherSuites() []uint16 { - once.Do(initDefaultCipherSuites) - return varDefaultCipherSuites -} - -func defaultCipherSuitesTLS13() []uint16 { - once.Do(initDefaultCipherSuites) - return varDefaultCipherSuitesTLS13 -} - -func initDefaultCipherSuites() { - var topCipherSuites []uint16 - - // Check the cpu flags for each platform that has optimized GCM implementations. - // Worst case, these variables will just all be false. - var ( - hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ - hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL - // Keep in sync with crypto/aes/cipher_s390x.go. - hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR && (cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM) - - hasGCMAsm = hasGCMAsmAMD64 || hasGCMAsmARM64 || hasGCMAsmS390X - ) - - if hasGCMAsm { - // If AES-GCM hardware is provided then prioritise AES-GCM - // cipher suites. - topCipherSuites = []uint16{ - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, - } - varDefaultCipherSuitesTLS13 = []uint16{ - TLS_AES_128_GCM_SHA256, - TLS_CHACHA20_POLY1305_SHA256, - TLS_AES_256_GCM_SHA384, - } - } else { - // Without AES-GCM hardware, we put the ChaCha20-Poly1305 - // cipher suites first. - topCipherSuites = []uint16{ - TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - } - varDefaultCipherSuitesTLS13 = []uint16{ - TLS_CHACHA20_POLY1305_SHA256, - TLS_AES_128_GCM_SHA256, - TLS_AES_256_GCM_SHA384, - } - } - - varDefaultCipherSuites = make([]uint16, 0, len(cipherSuites)) - varDefaultCipherSuites = append(varDefaultCipherSuites, topCipherSuites...) - -NextCipherSuite: - for _, suite := range cipherSuites { - if suite.flags&suiteDefaultOff != 0 { - continue - } - for _, existing := range varDefaultCipherSuites { - if existing == suite.id { - continue NextCipherSuite - } - } - varDefaultCipherSuites = append(varDefaultCipherSuites, suite.id) - } -} - -func unexpectedMessageError(wanted, got interface{}) error { +func unexpectedMessageError(wanted, got any) error { return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted) } @@ -1448,3 +1493,18 @@ func isSupportedSignatureAlgorithm(sigAlg SignatureScheme, supportedSignatureAlg } return false } + +// CertificateVerificationError is returned when certificate verification fails during the handshake. +type CertificateVerificationError struct { + // UnverifiedCertificates and its contents should not be modified. + UnverifiedCertificates []*x509.Certificate + Err error +} + +func (e *CertificateVerificationError) Error() string { + return fmt.Sprintf("tls: failed to verify certificate: %s", e.Err) +} + +func (e *CertificateVerificationError) Unwrap() error { + return e.Err +} diff --git a/pkg/tls/common_string.go b/pkg/tls/common_string.go new file mode 100644 index 000000000..238108811 --- /dev/null +++ b/pkg/tls/common_string.go @@ -0,0 +1,116 @@ +// Code generated by "stringer -type=SignatureScheme,CurveID,ClientAuthType -output=common_string.go"; DO NOT EDIT. + +package tls + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[PKCS1WithSHA256-1025] + _ = x[PKCS1WithSHA384-1281] + _ = x[PKCS1WithSHA512-1537] + _ = x[PSSWithSHA256-2052] + _ = x[PSSWithSHA384-2053] + _ = x[PSSWithSHA512-2054] + _ = x[ECDSAWithP256AndSHA256-1027] + _ = x[ECDSAWithP384AndSHA384-1283] + _ = x[ECDSAWithP521AndSHA512-1539] + _ = x[Ed25519-2055] + _ = x[PKCS1WithSHA1-513] + _ = x[ECDSAWithSHA1-515] +} + +const ( + _SignatureScheme_name_0 = "PKCS1WithSHA1" + _SignatureScheme_name_1 = "ECDSAWithSHA1" + _SignatureScheme_name_2 = "PKCS1WithSHA256" + _SignatureScheme_name_3 = "ECDSAWithP256AndSHA256" + _SignatureScheme_name_4 = "PKCS1WithSHA384" + _SignatureScheme_name_5 = "ECDSAWithP384AndSHA384" + _SignatureScheme_name_6 = "PKCS1WithSHA512" + _SignatureScheme_name_7 = "ECDSAWithP521AndSHA512" + _SignatureScheme_name_8 = "PSSWithSHA256PSSWithSHA384PSSWithSHA512Ed25519" +) + +var ( + _SignatureScheme_index_8 = [...]uint8{0, 13, 26, 39, 46} +) + +func (i SignatureScheme) String() string { + switch { + case i == 513: + return _SignatureScheme_name_0 + case i == 515: + return _SignatureScheme_name_1 + case i == 1025: + return _SignatureScheme_name_2 + case i == 1027: + return _SignatureScheme_name_3 + case i == 1281: + return _SignatureScheme_name_4 + case i == 1283: + return _SignatureScheme_name_5 + case i == 1537: + return _SignatureScheme_name_6 + case i == 1539: + return _SignatureScheme_name_7 + case 2052 <= i && i <= 2055: + i -= 2052 + return _SignatureScheme_name_8[_SignatureScheme_index_8[i]:_SignatureScheme_index_8[i+1]] + default: + return "SignatureScheme(" + strconv.FormatInt(int64(i), 10) + ")" + } +} +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[CurveP256-23] + _ = x[CurveP384-24] + _ = x[CurveP521-25] + _ = x[X25519-29] +} + +const ( + _CurveID_name_0 = "CurveP256CurveP384CurveP521" + _CurveID_name_1 = "X25519" +) + +var ( + _CurveID_index_0 = [...]uint8{0, 9, 18, 27} +) + +func (i CurveID) String() string { + switch { + case 23 <= i && i <= 25: + i -= 23 + return _CurveID_name_0[_CurveID_index_0[i]:_CurveID_index_0[i+1]] + case i == 29: + return _CurveID_name_1 + default: + return "CurveID(" + strconv.FormatInt(int64(i), 10) + ")" + } +} +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[NoClientCert-0] + _ = x[RequestClientCert-1] + _ = x[RequireAnyClientCert-2] + _ = x[VerifyClientCertIfGiven-3] + _ = x[RequireAndVerifyClientCert-4] +} + +const _ClientAuthType_name = "NoClientCertRequestClientCertRequireAnyClientCertVerifyClientCertIfGivenRequireAndVerifyClientCert" + +var _ClientAuthType_index = [...]uint8{0, 12, 29, 49, 72, 98} + +func (i ClientAuthType) String() string { + if i < 0 || i >= ClientAuthType(len(_ClientAuthType_index)-1) { + return "ClientAuthType(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _ClientAuthType_name[_ClientAuthType_index[i]:_ClientAuthType_index[i+1]] +} diff --git a/pkg/tls/conn.go b/pkg/tls/conn.go index 5c588499f..fb7af6969 100644 --- a/pkg/tls/conn.go +++ b/pkg/tls/conn.go @@ -7,14 +7,17 @@ package tls import ( + "context" "crypto/cipher" "crypto/subtle" "crypto/x509" "errors" "fmt" + "hash" "io" "net" "sync" + "time" "github.com/panjf2000/gnet/v2/pkg/buffer/elastic" ) @@ -23,8 +26,10 @@ import ( // It implements the net.Conn interface. type Conn struct { // constant - conn conn - isClient bool + conn net.Conn + isClient bool + handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake + // handshakeStatus is 1 if the connection is currently transferring // application data (i.e. is not currently processing a handshake). // This field is only to be accessed with sync/atomic. @@ -44,6 +49,9 @@ type Conn struct { ocspResponse []byte // stapled OCSP response scts [][]byte // signed certificate timestamps from server peerCertificates []*x509.Certificate + // activeCertHandles contains the cache handles to certificates in + // peerCertificates that are used to track active references. + activeCertHandles []*activeCert // verifiedChains contains the certificate chains that we built, as // opposed to the ones presented by the server. verifiedChains [][]*x509.Certificate @@ -83,19 +91,16 @@ type Conn struct { clientFinished [12]byte serverFinished [12]byte - clientProtocol string - clientProtocolFallback bool + // clientProtocol is the negotiated ALPN protocol. + clientProtocol string // input/output - in, out halfConn - rawInput MsgBuffer // raw input, starting with a record header - input *elastic.RingBuffer // a buffer for decrypted records - // pointer to the inboundBuffer of gnet.conn - hand MsgBuffer // handshake data waiting to be read - outBuf []byte // scratch buffer used by out.encrypt - buffering bool // whether records are buffered in sendBuf - sendBuf *elastic.Buffer // a buffer for records waiting to be sent - // also point to the outboundBuffer of gnet.conn + in, out halfConn + rawInput MsgBuffer // raw input, starting with a record header + input *elastic.RingBuffer // a buffer for decrypted records pointer to the inboundBuffer of gnet.conn + hand MsgBuffer // handshake data waiting to be read + // buffering bool // whether records are buffered in sendBuf + sendBuf *elastic.Buffer // a buffer for records waiting to be sent also point to the outboundBuffer of gnet.conn // bytesSent counts the bytes of application data sent. // packetsSent counts packets. @@ -107,42 +112,88 @@ type Conn struct { // handshake, nor deliver application data. Protected by in.Mutex. retryCount int + // activeCall indicates whether Close has been call in the low bit. + // the rest of the bits are the number of goroutines in Conn.Write. + // activeCall atomic.Int32 + tmp [16]byte hs interface { handshake() error } } +// Access to net.Conn methods. +// Cannot just embed net.Conn because that would +// export the struct field too. + +// LocalAddr returns the local network address. +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// SetDeadline sets the read and write deadlines associated with the connection. +// A zero value for t means Read and Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +// SetReadDeadline sets the read deadline on the underlying connection. +// A zero value for t means Read will not time out. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline on the underlying connection. +// A zero value for t means Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +// NetConn returns the underlying connection that is wrapped by c. +// Note that writing to or reading from this connection directly will corrupt the +// TLS session. +func (c *Conn) NetConn() net.Conn { + return c.conn +} + // A halfConn represents one direction of the record layer // connection, either sending or receiving. type halfConn struct { sync.Mutex - err error // first permanent error - version uint16 // protocol version - cipher interface{} // cipher algorithm - mac macFunction - seq [8]byte // 64-bit sequence number - additionalData [13]byte // to avoid allocs; interface method args escape + err error // first permanent error + version uint16 // protocol version + cipher interface{} // cipher algorithm + mac hash.Hash + seq [8]byte // 64-bit sequence number + + scratchBuf [13]byte // to avoid allocs; interface method args escape nextCipher interface{} // next encryption state - nextMac macFunction // next MAC algorithm + nextMac hash.Hash // next MAC algorithm trafficSecret []byte // current TLS 1.3 traffic secret } -type permamentError struct { +type permanentError struct { err net.Error } -func (e *permamentError) Error() string { return e.err.Error() } -func (e *permamentError) Unwrap() error { return e.err } -func (e *permamentError) Timeout() bool { return e.err.Timeout() } -func (e *permamentError) Temporary() bool { return false } +func (e *permanentError) Error() string { return e.err.Error() } +func (e *permanentError) Unwrap() error { return e.err } +func (e *permanentError) Timeout() bool { return e.err.Timeout() } +func (e *permanentError) Temporary() bool { return false } func (hc *halfConn) setErrorLocked(err error) error { if e, ok := err.(net.Error); ok { - hc.err = &permamentError{err: e} + hc.err = &permanentError{err: e} } else { hc.err = err } @@ -151,7 +202,7 @@ func (hc *halfConn) setErrorLocked(err error) error { // prepareCipherSpec sets the encryption and MAC states // that a subsequent changeCipherSpec will use. -func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) { +func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac hash.Hash) { hc.version = version hc.nextCipher = cipher hc.nextMac = mac @@ -313,15 +364,14 @@ func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) { } payload = payload[explicitNonceLen:] - additionalData := hc.additionalData[:] + var additionalData []byte if hc.version == VersionTLS13 { additionalData = record[:recordHeaderLen] } else { - copy(additionalData, hc.seq[:]) - copy(additionalData[8:], record[:3]) + additionalData = append(hc.scratchBuf[:0], hc.seq[:]...) + additionalData = append(additionalData, record[:3]...) n := len(payload) - c.Overhead() - additionalData[11] = byte(n >> 8) - additionalData[12] = byte(n) + additionalData = append(additionalData, byte(n>>8), byte(n)) } var err error @@ -387,7 +437,7 @@ func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) { record[3] = byte(n >> 8) record[4] = byte(n) remoteMAC := payload[n : n+macSize] - localMAC := hc.mac.MAC(hc.seq[0:], record[:recordHeaderLen], payload[:n], payload[n+macSize:]) + localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:]) // This is equivalent to checking the MACs and paddingGood // separately, but in constant-time to prevent distinguishing @@ -407,6 +457,10 @@ func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) { hc.incSeq() return plaintext, typ, nil } + +// sliceForAppend extends the input slice by n bytes. head is the full extended +// slice, while tail is the appended part. If the original slice has sufficient +// capacity no allocation is performed. func sliceForAppend(in []byte, n int) (head, tail []byte) { if total := len(in) + n; cap(in) >= total { head = in[:total] @@ -419,7 +473,7 @@ func sliceForAppend(in []byte, n int) (head, tail []byte) { } // encrypt encrypts payload, adding the appropriate nonce and/or MAC, and -// appends it to record, which contains the record header. +// appends it to record, which must already contain the record header. func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) { if hc.cipher == nil { return append(record, payload...), nil @@ -436,7 +490,7 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err // an 8 bytes nonce but its nonces must be unpredictable (see RFC // 5246, Appendix F.3), forcing us to use randomness. That's not // 3DES' biggest problem anyway because the birthday bound on block - // collision is reached first due to its simlarly small block size + // collision is reached first due to its similarly small block size // (see the Sweet32 attack). copy(explicitNonce, hc.seq[:]) } else { @@ -446,14 +500,10 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err } } - var mac []byte - if hc.mac != nil { - mac = hc.mac.MAC(hc.seq[:], record[:recordHeaderLen], payload, nil) - } - var dst []byte switch c := hc.cipher.(type) { case cipher.Stream: + mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) record, dst = sliceForAppend(record, len(payload)+len(mac)) c.XORKeyStream(dst[:len(payload)], payload) c.XORKeyStream(dst[len(payload):], mac) @@ -477,11 +527,12 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err record = c.Seal(record[:recordHeaderLen], nonce, record[recordHeaderLen:], record[:recordHeaderLen]) } else { - copy(hc.additionalData[:], hc.seq[:]) - copy(hc.additionalData[8:], record) - record = c.Seal(record, nonce, payload, hc.additionalData[:]) + additionalData := append(hc.scratchBuf[:0], hc.seq[:]...) + additionalData = append(additionalData, record[:recordHeaderLen]...) + record = c.Seal(record, nonce, payload, additionalData) } case cbcMode: + mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) blockSize := c.BlockSize() plaintextLen := len(payload) + len(mac) paddingLen := blockSize - plaintextLen%blockSize @@ -519,12 +570,12 @@ type RecordHeaderError struct { // sent an initial handshake that didn't look like TLS. // It is nil if there's already been a handshake or a TLS alert has // been written to the connection. - Conn conn + Conn net.Conn } func (e RecordHeaderError) Error() string { return "tls: " + e.Msg } -func (c *Conn) newRecordHeaderError(conn conn, msg string) (err RecordHeaderError) { +func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) { err.Msg = msg err.Conn = conn copy(err.RecordHeader[:], c.rawInput.Bytes()) @@ -532,7 +583,7 @@ func (c *Conn) newRecordHeaderError(conn conn, msg string) (err RecordHeaderErro } func (c *Conn) readRecord() error { - if c.rawInput.Len() > 5 { + if c.rawInput.Len() > recordHeaderLen { return c.readRecordOrCCS(false) } return io.EOF @@ -540,17 +591,22 @@ func (c *Conn) readRecord() error { func (c *Conn) readChangeCipherSpec() error { c.input.Reset() - return c.readRecordOrCCS(true) + if c.rawInput.Len() > recordHeaderLen { + return c.readRecordOrCCS(true) + } + return io.EOF } // readRecordOrCCS reads one or more TLS records from the connection and // updates the record layer state. Some invariants: -// * c.in must be locked -// * c.input must be empty +// - c.in must be locked +// - c.input must be empty +// // During the handshake one and only one of the following will happen: // - c.hand grows // - c.in.changeCipherSpec is called // - an error is returned +// // After the handshake one and only one of the following will happen: // - c.hand grows // - c.input is set @@ -559,6 +615,7 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { if c.in.err != nil { return c.in.err } + handshakeComplete := c.HandshakeComplete() hdr := c.rawInput.Bytes() typ := recordType(hdr[0]) @@ -567,24 +624,24 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { // start with a uint16 length where the MSB is set and the first record // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests // an SSLv2 client. + if !handshakeComplete && typ == 0x80 { + c.sendAlert(alertProtocolVersion) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received")) + } vers := uint16(hdr[1])<<8 | uint16(hdr[2]) n := int(hdr[3])<<8 | int(hdr[4]) + if len(hdr) < recordHeaderLen+n { return io.EOF } - // Read header, payload. - if c.handshakeStatus != 255 && typ == 0x80 { + + if c.haveVers && c.vers != VersionTLS13 && vers != c.vers { c.sendAlert(alertProtocolVersion) - return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received")) + msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) } - if c.haveVers { - if c.vers != VersionTLS13 && vers != c.vers { - c.sendAlert(alertProtocolVersion) - msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers) - return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) - } - } else { + if !c.haveVers { // First message, be extra suspicious: this might not be a TLS // client. Bail out before reading a full 'body', if possible. // The current max version is 3.3 so if the version is >= 16.0, @@ -602,7 +659,6 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { // Process message. c.rawInput.Shift(recordHeaderLen + n) data, typ, err := c.in.decrypt(hdr[:recordHeaderLen+n]) - if err != nil { return c.in.setErrorLocked(c.sendAlert(err.(alert))) } @@ -673,7 +729,7 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { } case recordTypeApplicationData: - if c.handshakeStatus != 255 || expectChangeCipherSpec { + if !handshakeComplete || expectChangeCipherSpec { return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) } // Some OpenSSL servers send empty records in order to randomize the @@ -696,7 +752,7 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { return nil } -// retryReadRecord recurses into readRecordOrCCS to drop a non-advancing record, like +// retryReadRecord recurs into readRecordOrCCS to drop a non-advancing record, like // a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3. func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error { c.retryCount++ @@ -705,7 +761,7 @@ func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error { return c.in.setErrorLocked(errors.New("tls: too many ignored records")) } c.input.Reset() - if c.rawInput.Len() > 5 { + if c.rawInput.Len() > recordHeaderLen { return c.readRecordOrCCS(expectChangeCipherSpec) } return io.EOF @@ -732,7 +788,6 @@ func (c *Conn) sendAlertLocked(err alert) error { // sendAlert sends a TLS alert message. func (c *Conn) sendAlert(err alert) error { - return c.sendAlertLocked(err) } @@ -813,12 +868,12 @@ func (c *Conn) maxPayloadSizeForWrite(typ recordType) int { return n } -func (c *Conn) write(data []byte) (n int, err error) { +func (c *Conn) write(data []byte) (int, error) { //必须把所有数据往buf写 - n = len(data) + n := len(data) c.sendBuf.Write(data) c.bytesSent += int64(n) - return + return n, nil } func (c *Conn) flush() (int, error) { @@ -827,23 +882,42 @@ func (c *Conn) flush() (int, error) { } n, err := c.conn.Write(nil) c.bytesSent += int64(n) - c.buffering = false + // c.buffering = false return n, err } +// outBufPool pools the record-sized scratch buffers used by writeRecordLocked. +var outBufPool = sync.Pool{ + New: func() any { + return new([]byte) + }, +} + // writeRecordLocked writes a TLS record with the given type and payload to the // connection and updates the record layer state. -func (c *Conn) writeRecordLocked(typ recordType, data []byte) (n int, err error) { - +func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { + outBufPtr := outBufPool.Get().(*[]byte) + outBuf := *outBufPtr + defer func() { + // You might be tempted to simplify this by just passing &outBuf to Put, + // but that would make the local copy of the outBuf slice header escape + // to the heap, causing an allocation. Instead, we keep around the + // pointer to the slice header returned by Get, which is already on the + // heap, and overwrite and return that. + *outBufPtr = outBuf + outBufPool.Put(outBufPtr) + }() + + var n int for len(data) > 0 { m := len(data) if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload { m = maxPayload } - _, c.outBuf = sliceForAppend(c.outBuf[:0], recordHeaderLen) - c.outBuf[0] = byte(typ) - /*vers := c.vers + _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen) + outBuf[0] = byte(typ) + vers := c.vers if vers == 0 { // Some TLS servers fail if the record version is // greater than TLS 1.0 for the initial ClientHello. @@ -853,16 +927,17 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (n int, err error) // See RFC 8446, Section 5.1. vers = VersionTLS12 } - c.outBuf[1] = byte(vers >> 8) - c.outBuf[2] = byte(vers)*/ - c.outBuf[3] = byte(m >> 8) - c.outBuf[4] = byte(m) + outBuf[1] = byte(vers >> 8) + outBuf[2] = byte(vers) + outBuf[3] = byte(m >> 8) + outBuf[4] = byte(m) - c.outBuf, err = c.out.encrypt(c.outBuf, data[:m], c.config.rand()) + var err error + outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand()) if err != nil { return n, err } - if _, err = c.write(c.outBuf); err != nil { + if _, err := c.write(outBuf); err != nil { return n, err } n += m @@ -870,18 +945,17 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (n int, err error) } if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 { - if err = c.out.changeCipherSpec(); err != nil { + if err := c.out.changeCipherSpec(); err != nil { return n, c.sendAlertLocked(err.(alert)) } } - return + return n, nil } // writeRecord writes a TLS record with the given type and payload to the // connection and updates the record layer state. func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) { - return c.writeRecordLocked(typ, data) } @@ -970,26 +1044,30 @@ func (c *Conn) readHandshake() (interface{}, error) { } var ( - errClosed = errors.New("tls: use of closed connection") errShutdown = errors.New("tls: protocol is shutdown") ) // Write writes data to the connection. -func (c *Conn) Write(b []byte) error { +// +// As Write calls Handshake, in order to prevent indefinite blocking a deadline +// must be set for both Read and Write before Write is called when the handshake +// has not yet completed. See SetDeadline, SetReadDeadline, and +// SetWriteDeadline. +func (c *Conn) Write(b []byte) (int, error) { // interlock with Close below - if c.handshakeStatus != 255 { - return nil + if !c.HandshakeComplete() { + return 0, nil } - c.buffering = false + // c.buffering = false if err := c.out.err; err != nil { - return err + return 0, err } if c.closeNotifySent { - return errShutdown + return 0, errShutdown } // TLS 1.0 is susceptible to a chosen-plaintext @@ -1001,8 +1079,19 @@ func (c *Conn) Write(b []byte) error { // https://bugzilla.mozilla.org/show_bug.cgi?id=665814 // https://www.imperialviolet.org/2012/01/15/beastfollowup.html - _, err := c.writeRecordLocked(recordTypeApplicationData, b) - return c.out.setErrorLocked(err) + var m int + if len(b) > 1 && c.vers == VersionTLS10 { + if _, ok := c.out.cipher.(cipher.BlockMode); ok { + n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1]) + if err != nil { + return n, c.out.setErrorLocked(err) + } + m, b = 1, b[1:] + } + } + + n, err := c.writeRecordLocked(recordTypeApplicationData, b) + return n + m, c.out.setErrorLocked(err) } // load the data into the TLS rawInput @@ -1014,7 +1103,7 @@ func (c *Conn) RawWrite(data []byte) (int, error) { // Decrypt one tls record and save it in the 解析一条tls数据 func (c *Conn) ReadFrame() error { - if c.rawInput.Len() > 5 { + if c.rawInput.Len() > recordHeaderLen { return c.readRecordOrCCS(false) } return io.EOF @@ -1024,112 +1113,8 @@ func (c *Conn) RawData() []byte { return c.rawInput.Bytes() } -// Close closes the connection. -var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete") - -// CloseWrite shuts down the writing side of the connection. It should only be -// called once the handshake has completed and does not call CloseWrite on the -// underlying connection. Most callers should just use Close. -func (c *Conn) CloseWrite() error { - if c.handshakeStatus != 255 { - return errEarlyCloseWrite - } - - return c.closeNotify() -} - -func (c *Conn) closeNotify() error { - if !c.closeNotifySent { - c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify) - c.closeNotifySent = true - } - return c.closeNotifyErr -} - -// Handshake runs the client or server handshake -// protocol if it has not yet been run. -// Most uses of this package need not call Handshake -// explicitly: the first Read or Write will call it automatically. -func (c *Conn) Handshake() error { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - - if err := c.handshakeErr; err != nil { - return err - } - if c.handshakeStatus == 255 { - return nil - } - - if c.isClient { - c.handshakeErr = c.clientHandshake() - } else { - c.handshakeErr = c.serverHandshake() - } - - if c.handshakeErr == io.EOF { - c.handshakeErr = nil - } - if c.handshakeErr == nil { - c.handshakes++ - } else { - //panic(c.handshakeErr) - // If an error occurred during the handshake try to flush the - // alert that might be left in the buffer. - c.flush() - } - - return c.handshakeErr -} - -// ConnectionState returns basic TLS details about the connection. -func (c *Conn) ConnectionState() ConnectionState { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - - var state ConnectionState - state.HandshakeComplete = c.handshakeStatus == 255 - state.ServerName = c.serverName - - if state.HandshakeComplete { - state.Version = c.vers - state.NegotiatedProtocol = c.clientProtocol - state.DidResume = c.didResume - state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback - state.CipherSuite = c.cipherSuite - state.PeerCertificates = c.peerCertificates - state.VerifiedChains = c.verifiedChains - state.SignedCertificateTimestamps = c.scts - state.OCSPResponse = c.ocspResponse - if !c.didResume && c.vers != VersionTLS13 { - if c.clientFinishedIsFirst { - state.TLSUnique = c.clientFinished[:] - } else { - state.TLSUnique = c.serverFinished[:] - } - } - if c.config.Renegotiation != RenegotiateNever { - state.ekm = noExportedKeyingMaterial - } else { - state.ekm = c.ekm - } - } - - return state -} - -// OCSPResponse returns the stapled OCSP response from the TLS server, if -// any. (Only valid for client connections.) -func (c *Conn) OCSPResponse() []byte { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - - return c.ocspResponse -} -// VerifyHostname checks that the peer certificate chain is valid for -// connecting to host. If so, it returns nil; if not, it returns an error -// describing the problem. +// handleRenegotiation processes a HelloRequest handshake message. func (c *Conn) handleRenegotiation() error { if c.vers == VersionTLS13 { return errors.New("tls: internal error: unexpected renegotiation") @@ -1168,7 +1153,7 @@ func (c *Conn) handleRenegotiation() error { defer c.handshakeMutex.Unlock() c.handshakeStatus = 0 - if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil { + if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil { c.handshakes++ } return c.handshakeErr @@ -1228,13 +1213,153 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { return nil } + +// Read reads data from the connection. +// +// As Read calls Handshake, in order to prevent indefinite blocking a deadline +// must be set for both Read and Write before Read is called when the handshake +// has not yet completed. See SetDeadline, SetReadDeadline, and +// SetWriteDeadline. +func (c *Conn) Read(b []byte) (int, error) { + if !c.HandshakeComplete() { + return 0, nil + } + if len(b) == 0 { + // Put this after Handshake, in case people were calling + // Read(nil) for the side effect of the Handshake. + return 0, nil + } + + for c.input.Len() == 0 { + if err := c.readRecord(); err != nil { + return 0, err + } + for c.hand.Len() > 0 { + if err := c.handlePostHandshakeMessage(); err != nil { + return 0, err + } + } + } + + n, _ := c.input.Read(b) + + // If a close-notify alert is waiting, read it so that we can return (n, + // EOF) instead of (n, nil), to signal to the HTTP response reading + // goroutine that the connection is now closed. This eliminates a race + // where the HTTP response reading goroutine would otherwise not observe + // the EOF until its next read, by which time a client goroutine might + // have already tried to reuse the HTTP connection for a new request. + // See https://golang.org/cl/76400046 and https://golang.org/issue/3514 + if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 && + recordType(c.rawInput.Bytes()[0]) == recordTypeAlert { + if err := c.readRecord(); err != nil { + return n, err // will be io.EOF on closeNotify + } + } + + return n, nil +} + +// Close closes the connection. +func (c *Conn) Close() error { + var alertErr error + if c.HandshakeComplete() { + if err := c.closeNotify(); err != nil { + alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err) + } + } + return alertErr +} + +var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete") + +// CloseWrite shuts down the writing side of the connection. It should only be +// called once the handshake has completed and does not call CloseWrite on the +// underlying connection. Most callers should just use Close. +func (c *Conn) CloseWrite() error { + if !c.HandshakeComplete() { + return errEarlyCloseWrite + } + + return c.closeNotify() +} + +func (c *Conn) closeNotify() error { + if !c.closeNotifySent { + c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify) + c.closeNotifySent = true + } + return c.closeNotifyErr +} + +// Handshake runs the client or server handshake +// protocol if it has not yet been run. +// +// Most uses of this package need not call Handshake explicitly: the +// first Read or Write will call it automatically. +// +// For control over canceling or setting a timeout on a handshake, use +// HandshakeContext or the Dialer's DialContext method instead. +func (c *Conn) Handshake() error { + return c.HandshakeContext(context.Background()) +} + +// HandshakeContext runs the client or server handshake +// protocol if it has not yet been run. +// +// The provided Context must be non-nil. If the context is canceled before +// the handshake is complete, the handshake is interrupted and an error is returned. +// Once the handshake has completed, cancellation of the context will not affect the +// connection. +// +// Most uses of this package need not call HandshakeContext explicitly: the +// first Read or Write will call it automatically. +func (c *Conn) HandshakeContext(ctx context.Context) error { + // Delegate to unexported method for named return + // without confusing documented signature. + return c.handshakeContext(ctx) +} + +func (c *Conn) handshakeContext(ctx context.Context) (ret error) { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + if err := c.handshakeErr; err != nil { + return err + } + if c.HandshakeComplete() { + return nil + } + + c.handshakeErr = c.handshakeFn(ctx) + if c.handshakeErr == io.EOF { + c.handshakeErr = nil + } + if c.handshakeErr == nil { + c.handshakes++ + } else { + // If an error occurred during the handshake try to flush the + // alert that might be left in the buffer. + c.flush() + } + + return c.handshakeErr +} + +// ConnectionState returns basic TLS details about the connection. +func (c *Conn) ConnectionState() ConnectionState { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + return c.connectionStateLocked() +} + func (c *Conn) connectionStateLocked() ConnectionState { var state ConnectionState - state.HandshakeComplete = c.handshakeStatus == 255 + state.HandshakeComplete = c.HandshakeComplete() state.Version = c.vers state.NegotiatedProtocol = c.clientProtocol state.DidResume = c.didResume - state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback + state.NegotiatedProtocolIsMutual = true state.ServerName = c.serverName state.CipherSuite = c.cipherSuite state.PeerCertificates = c.peerCertificates @@ -1255,6 +1380,34 @@ func (c *Conn) connectionStateLocked() ConnectionState { } return state } + func (c *Conn) HandshakeComplete() bool { return c.handshakeStatus == 255 } + +// OCSPResponse returns the stapled OCSP response from the TLS server, if +// any. (Only valid for client connections.) +func (c *Conn) OCSPResponse() []byte { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + return c.ocspResponse +} + +// VerifyHostname checks that the peer certificate chain is valid for +// connecting to host. If so, it returns nil; if not, it returns an error +// describing the problem. +func (c *Conn) VerifyHostname(host string) error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + if !c.isClient { + return errors.New("tls: VerifyHostname called on TLS server connection") + } + if !c.HandshakeComplete() { + return errors.New("tls: handshake has not yet been performed") + } + if len(c.verifiedChains) == 0 { + return errors.New("tls: handshake did not verify certificate chain") + } + return c.peerCertificates[0].VerifyHostname(host) +} diff --git a/pkg/tls/generate_cert.go b/pkg/tls/generate_cert.go new file mode 100644 index 000000000..cd4bfc513 --- /dev/null +++ b/pkg/tls/generate_cert.go @@ -0,0 +1,171 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build ignore + +// Generate a self-signed X.509 certificate for a TLS server. Outputs to +// 'cert.pem' and 'key.pem' and will overwrite existing files. + +package main + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "flag" + "log" + "math/big" + "net" + "os" + "strings" + "time" +) + +var ( + host = flag.String("host", "", "Comma-separated hostnames and IPs to generate a certificate for") + validFrom = flag.String("start-date", "", "Creation date formatted as Jan 1 15:04:05 2011") + validFor = flag.Duration("duration", 365*24*time.Hour, "Duration that certificate is valid for") + isCA = flag.Bool("ca", false, "whether this cert should be its own Certificate Authority") + rsaBits = flag.Int("rsa-bits", 2048, "Size of RSA key to generate. Ignored if --ecdsa-curve is set") + ecdsaCurve = flag.String("ecdsa-curve", "", "ECDSA curve to use to generate a key. Valid values are P224, P256 (recommended), P384, P521") + ed25519Key = flag.Bool("ed25519", false, "Generate an Ed25519 key") +) + +func publicKey(priv any) any { + switch k := priv.(type) { + case *rsa.PrivateKey: + return &k.PublicKey + case *ecdsa.PrivateKey: + return &k.PublicKey + case ed25519.PrivateKey: + return k.Public().(ed25519.PublicKey) + default: + return nil + } +} + +func main() { + flag.Parse() + + if len(*host) == 0 { + log.Fatalf("Missing required --host parameter") + } + + var priv any + var err error + switch *ecdsaCurve { + case "": + if *ed25519Key { + _, priv, err = ed25519.GenerateKey(rand.Reader) + } else { + priv, err = rsa.GenerateKey(rand.Reader, *rsaBits) + } + case "P224": + priv, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader) + case "P256": + priv, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + case "P384": + priv, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + case "P521": + priv, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + default: + log.Fatalf("Unrecognized elliptic curve: %q", *ecdsaCurve) + } + if err != nil { + log.Fatalf("Failed to generate private key: %v", err) + } + + // ECDSA, ED25519 and RSA subject keys should have the DigitalSignature + // KeyUsage bits set in the x509.Certificate template + keyUsage := x509.KeyUsageDigitalSignature + // Only RSA subject keys should have the KeyEncipherment KeyUsage bits set. In + // the context of TLS this KeyUsage is particular to RSA key exchange and + // authentication. + if _, isRSA := priv.(*rsa.PrivateKey); isRSA { + keyUsage |= x509.KeyUsageKeyEncipherment + } + + var notBefore time.Time + if len(*validFrom) == 0 { + notBefore = time.Now() + } else { + notBefore, err = time.Parse("Jan 2 15:04:05 2006", *validFrom) + if err != nil { + log.Fatalf("Failed to parse creation date: %v", err) + } + } + + notAfter := notBefore.Add(*validFor) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + log.Fatalf("Failed to generate serial number: %v", err) + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Acme Co"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: keyUsage, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + hosts := strings.Split(*host, ",") + for _, h := range hosts { + if ip := net.ParseIP(h); ip != nil { + template.IPAddresses = append(template.IPAddresses, ip) + } else { + template.DNSNames = append(template.DNSNames, h) + } + } + + if *isCA { + template.IsCA = true + template.KeyUsage |= x509.KeyUsageCertSign + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv) + if err != nil { + log.Fatalf("Failed to create certificate: %v", err) + } + + certOut, err := os.Create("cert.pem") + if err != nil { + log.Fatalf("Failed to open cert.pem for writing: %v", err) + } + if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { + log.Fatalf("Failed to write data to cert.pem: %v", err) + } + if err := certOut.Close(); err != nil { + log.Fatalf("Error closing cert.pem: %v", err) + } + log.Print("wrote cert.pem\n") + + keyOut, err := os.OpenFile("key.pem", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + log.Fatalf("Failed to open key.pem for writing: %v", err) + } + privBytes, err := x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + log.Fatalf("Unable to marshal private key: %v", err) + } + if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { + log.Fatalf("Failed to write data to key.pem: %v", err) + } + if err := keyOut.Close(); err != nil { + log.Fatalf("Error closing key.pem: %v", err) + } + log.Print("wrote key.pem\n") +} diff --git a/pkg/tls/handshake_client.go b/pkg/tls/handshake_client.go index f3f3d043b..669ab68cc 100644 --- a/pkg/tls/handshake_client.go +++ b/pkg/tls/handshake_client.go @@ -6,7 +6,9 @@ package tls import ( "bytes" + "context" "crypto" + "crypto/ecdh" "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" @@ -14,6 +16,7 @@ import ( "crypto/x509" "errors" "fmt" + "hash" "io" "net" "strconv" @@ -23,6 +26,7 @@ import ( type clientHandshakeState struct { c *Conn + ctx context.Context serverHello *serverHelloMsg hello *clientHelloMsg suite *cipherSuite @@ -33,7 +37,9 @@ type clientHandshakeState struct { cacheKey string } -func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) { +var testingOnlyForceClientHelloSignatureAlgorithms []SignatureScheme + +func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { config := c.config if len(config.ServerName) == 0 && !config.InsecureSkipVerify { return nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") @@ -51,12 +57,12 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) { return nil, nil, errors.New("tls: NextProtos values too large") } - supportedVersions := config.supportedVersions() + supportedVersions := config.supportedVersions(roleClient) if len(supportedVersions) == 0 { return nil, nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion") } - clientHelloVersion := config.maxSupportedVersion() + clientHelloVersion := config.maxSupportedVersion(roleClient) // The version at the beginning of the ClientHello was capped at TLS 1.2 // for compatibility reasons. The supported_versions extension is used // to negotiate versions now. See RFC 8446, Section 4.2.1. @@ -83,22 +89,24 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) { hello.secureRenegotiation = c.clientFinished[:] } - possibleCipherSuites := config.cipherSuites() - hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites)) + preferenceOrder := cipherSuitesPreferenceOrder + if !hasAESGCMHardwareSupport { + preferenceOrder = cipherSuitesPreferenceOrderNoAES + } + configCipherSuites := config.cipherSuites() + hello.cipherSuites = make([]uint16, 0, len(configCipherSuites)) - for _, suiteId := range possibleCipherSuites { - for _, suite := range cipherSuites { - if suite.id != suiteId { - continue - } - // Don't advertise TLS 1.2-only cipher suites unless - // we're attempting TLS 1.2. - if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 { - break - } - hello.cipherSuites = append(hello.cipherSuites, suiteId) - break + for _, suiteId := range preferenceOrder { + suite := mutualCipherSuite(configCipherSuites, suiteId) + if suite == nil { + continue } + // Don't advertise TLS 1.2-only cipher suites unless + // we're attempting TLS 1.2. + if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 { + continue + } + hello.cipherSuites = append(hello.cipherSuites, suiteId) } _, err := io.ReadFull(config.rand(), hello.random) @@ -114,28 +122,35 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) { } if hello.vers >= VersionTLS12 { - hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms + hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms() + } + if testingOnlyForceClientHelloSignatureAlgorithms != nil { + hello.supportedSignatureAlgorithms = testingOnlyForceClientHelloSignatureAlgorithms } - var params ecdheParameters + var key *ecdh.PrivateKey if hello.supportedVersions[0] == VersionTLS13 { - hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13()...) + if hasAESGCMHardwareSupport { + hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13...) + } else { + hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...) + } curveID := config.curvePreferences()[0] - if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + if _, ok := curveForCurveID(curveID); !ok { return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") } - params, err = generateECDHEParameters(config.rand(), curveID) + key, err = generateECDHEKey(config.rand(), curveID) if err != nil { return nil, nil, err } - hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} + hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} } - return hello, params, nil + return hello, key, nil } -func (c *Conn) clientHandshake() (err error) { +func (c *Conn) clientHandshake(ctx context.Context) (err error) { switch c.handshakeStatus { case 0: @@ -146,11 +161,11 @@ func (c *Conn) clientHandshake() (err error) { // need to be reset. c.didResume = false - hello, ecdheParams, err := c.makeClientHello() - if err != nil { - return err - } - c.serverName = hello.serverName + hello, ecdheKey, err := c.makeClientHello() + if err != nil { + return err + } + c.serverName = hello.serverName cacheKey, session, earlySecret, binderKey := c.loadSession(hello) if cacheKey != "" && session != nil { @@ -175,14 +190,16 @@ func (c *Conn) clientHandshake() (err error) { c.hs = &clientHandshakeStateTLS13{ //临时缓存 c: c, + ctx: ctx, hello: hello, - ecdheParams: ecdheParams, + ecdheKey: ecdheKey, session: session, earlySecret: earlySecret, binderKey: binderKey, cacheKey: cacheKey, } + return nil case 1: hello := c.hs.(*clientHandshakeStateTLS13).hello msg, err := c.readHandshake() @@ -200,29 +217,31 @@ func (c *Conn) clientHandshake() (err error) { return err } c.handshakeStatus = 2 - // If we are negotiating a protocol version that's lower than what we - // support, check for the server downgrade canaries. - // See RFC 8446, Section 4.1.3. - maxVers := c.config.maxSupportedVersion() - tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12 - tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11 - if maxVers == VersionTLS13 && c.vers <= VersionTLS12 && (tls12Downgrade || tls11Downgrade) || - maxVers == VersionTLS12 && c.vers <= VersionTLS11 && tls11Downgrade { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox") - } + // If we are negotiating a protocol version that's lower than what we + // support, check for the server downgrade canaries. + // See RFC 8446, Section 4.1.3. + maxVers := c.config.maxSupportedVersion(roleClient) + tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12 + tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11 + if maxVers == VersionTLS13 && c.vers <= VersionTLS12 && (tls12Downgrade || tls11Downgrade) || + maxVers == VersionTLS12 && c.vers <= VersionTLS11 && tls11Downgrade { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox") + } + hs13 := c.hs.(*clientHandshakeStateTLS13) if c.vers == VersionTLS13 { - c.hs.(*clientHandshakeStateTLS13).serverHello = serverHello + hs13.serverHello = serverHello // In TLS 1.3, session tickets are delivered after the handshake. return c.hs.handshake() } hs := &clientHandshakeState{ c: c, + ctx: ctx, serverHello: serverHello, hello: hello, - session: c.hs.(*clientHandshakeStateTLS13).session, - oldsession: c.hs.(*clientHandshakeStateTLS13).session, - cacheKey: c.hs.(*clientHandshakeStateTLS13).cacheKey, + session: hs13.session, + oldsession: hs13.session, + cacheKey: hs13.cacheKey, } c.hs = hs if err := hs.handshake(); err != nil { @@ -230,9 +249,11 @@ func (c *Conn) clientHandshake() (err error) { } case 3, 4, 5: c.hs.handshake() + default: return errors.New("tls handshakeStatus error:" + strconv.Itoa(int(c.handshakeStatus))) } + return nil } @@ -266,7 +287,6 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, // Check that version used for the previous session is still valid. versOk := false - for _, v := range hello.supportedVersions { if v == session.vers { versOk = true @@ -359,7 +379,7 @@ func (c *Conn) pickTLSVersion(serverHello *serverHelloMsg) error { peerVersion = serverHello.supportedVersion } - vers, ok := c.config.mutualVersion([]uint16{peerVersion}) + vers, ok := c.config.mutualVersion(roleClient, []uint16{peerVersion}) if !ok { c.sendAlert(alertProtocolVersion) return fmt.Errorf("tls: server selected unsupported protocol version %x", peerVersion) @@ -375,11 +395,13 @@ func (c *Conn) pickTLSVersion(serverHello *serverHelloMsg) error { // Does the handshake, either a full one or resumes old session. Requires hs.c, // hs.hello, hs.serverHello, and, optionally, hs.session to be set. -func (hs *clientHandshakeState) handshake() (err error) { +func (hs *clientHandshakeState) handshake() error { c := hs.c if c.handshakeStatus == 2 { - c.didResume, err = hs.processServerHello() + isResume, err := hs.processServerHello() + c.didResume = isResume + if err != nil { return err } @@ -490,7 +512,6 @@ func (hs *clientHandshakeState) doFullHandshakeStep1() error { certMsg, ok := msg.(*certificateMsg) if !ok || len(certMsg.certificates) == 0 { c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certMsg, msg) } hs.finishedHash.Write(certMsg.marshal()) @@ -568,7 +589,7 @@ func (hs *clientHandshakeState) doFullHandshakeStep2() error { certRequested = true hs.finishedHash.Write(certReq.marshal()) - cri := certificateRequestInfoFromMsg(c.vers, certReq) + cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq) if chainToSend, err = c.getClientCertificate(cri); err != nil { c.sendAlert(alertInternalError) return err @@ -642,7 +663,7 @@ func (hs *clientHandshakeState) doFullHandshakeStep2() error { } } - signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret) + signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash) signOpts := crypto.SignerOpts(sigHash) if sigType == signatureRSAPSS { signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} @@ -676,12 +697,12 @@ func (hs *clientHandshakeState) establishKeys() error { clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) var clientCipher, serverCipher interface{} - var clientHash, serverHash macFunction + var clientHash, serverHash hash.Hash if hs.suite.cipher != nil { clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */) - clientHash = hs.suite.mac(c.vers, clientMAC) + clientHash = hs.suite.mac(clientMAC) serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */) - serverHash = hs.suite.mac(c.vers, serverMAC) + serverHash = hs.suite.mac(serverMAC) } else { clientCipher = hs.suite.aead(clientKey, clientIV) serverCipher = hs.suite.aead(serverKey, serverIV) @@ -729,18 +750,12 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) { } } - clientDidALPN := len(hs.hello.alpnProtocols) > 0 - serverHasALPN := len(hs.serverHello.alpnProtocol) > 0 - - if !clientDidALPN && serverHasALPN { - c.sendAlert(alertHandshakeFailure) - return false, errors.New("tls: server advertised unrequested ALPN extension") + if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol); err != nil { + c.sendAlert(alertUnsupportedExtension) + return false, err } + c.clientProtocol = hs.serverHello.alpnProtocol - if serverHasALPN { - c.clientProtocol = hs.serverHello.alpnProtocol - c.clientProtocolFallback = false - } c.scts = hs.serverHello.scts if !hs.serverResumedSession() { @@ -767,9 +782,27 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) { if len(c.scts) == 0 && len(hs.session.scts) != 0 { c.scts = hs.session.scts } + return true, nil } +// checkALPN ensure that the server's choice of ALPN protocol is compatible with +// the protocols that we advertised in the Client Hello. +func checkALPN(clientProtos []string, serverProto string) error { + if serverProto == "" { + return nil + } + if len(clientProtos) == 0 { + return errors.New("tls: server advertised unrequested ALPN extension") + } + for _, proto := range clientProtos { + if proto == serverProto { + return nil + } + } + return errors.New("tls: server selected unadvertised ALPN protocol") +} + func (hs *clientHandshakeState) readFinished(out []byte) error { c := hs.c @@ -850,14 +883,16 @@ func (hs *clientHandshakeState) sendFinished(out []byte) error { // verifyServerCertificate parses and verifies the provided chain, setting // c.verifiedChains and c.peerCertificates or sending the appropriate alert. func (c *Conn) verifyServerCertificate(certificates [][]byte) error { + activeHandles := make([]*activeCert, len(certificates)) certs := make([]*x509.Certificate, len(certificates)) for i, asn1Data := range certificates { - cert, err := x509.ParseCertificate(asn1Data) + cert, err := clientCertCache.newCert(asn1Data) if err != nil { c.sendAlert(alertBadCertificate) return errors.New("tls: failed to parse certificate from server: " + err.Error()) } - certs[i] = cert + activeHandles[i] = cert + certs[i] = cert.cert } if !c.config.InsecureSkipVerify { @@ -867,6 +902,7 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error { DNSName: c.config.ServerName, Intermediates: x509.NewCertPool(), } + for _, cert := range certs[1:] { opts.Intermediates.AddCert(cert) } @@ -874,12 +910,10 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error { c.verifiedChains, err = certs[0].Verify(opts) if err != nil { c.sendAlert(alertBadCertificate) - return err + return &CertificateVerificationError{UnverifiedCertificates: certs, Err: err} } } - - switch certs[0].PublicKey.(type) { case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey: break @@ -888,6 +922,7 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error { return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey) } + c.activeCertHandles = activeHandles c.peerCertificates = certs if c.config.VerifyPeerCertificate != nil { @@ -909,10 +944,11 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error { // certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS // <= 1.2 CertificateRequest, making an effort to fill in missing information. -func certificateRequestInfoFromMsg(vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { +func certificateRequestInfoFromMsg(ctx context.Context, vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { cri := &CertificateRequestInfo{ AcceptableCAs: certReq.certificateAuthorities, Version: vers, + ctx: ctx, } var rsaAvail, ecAvail bool @@ -997,22 +1033,6 @@ func clientSessionCacheKey(serverAddr net.Addr, config *Config) string { return serverAddr.String() } -// mutualProtocol finds the mutual Next Protocol Negotiation or ALPN protocol -// given list of possible protocols and a list of the preference order. The -// first list must not be empty. It returns the resulting protocol and flag -// indicating if the fallback case was reached. -func mutualProtocol(protos, preferenceProtos []string) (string, bool) { - for _, s := range preferenceProtos { - for _, c := range protos { - if s == c { - return s, false - } - } - } - - return protos[0], true -} - // hostnameInSNI converts name into an appropriate hostname for SNI. // Literal IP addresses and absolute FQDNs are not permitted as SNI values. // See RFC 6066, Section 3. diff --git a/pkg/tls/handshake_client_tls13.go b/pkg/tls/handshake_client_tls13.go index 5cf35aec6..2b272c1c4 100644 --- a/pkg/tls/handshake_client_tls13.go +++ b/pkg/tls/handshake_client_tls13.go @@ -6,7 +6,9 @@ package tls import ( "bytes" + "context" "crypto" + "crypto/ecdh" "crypto/hmac" "crypto/rsa" "errors" @@ -16,9 +18,10 @@ import ( type clientHandshakeStateTLS13 struct { c *Conn + ctx context.Context serverHello *serverHelloMsg hello *clientHelloMsg - ecdheParams ecdheParameters + ecdheKey *ecdh.PrivateKey session *ClientSessionState earlySecret []byte @@ -34,11 +37,15 @@ type clientHandshakeStateTLS13 struct { cacheKey string } -// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheParams, and, +// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheKey, and, // optionally, hs.session, hs.earlySecret and hs.binderKey to be set. func (hs *clientHandshakeStateTLS13) handshake() error { c := hs.c + if needFIPS() { + return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode") + } + // The server must not select TLS 1.3 in a renegotiation. See RFC 8446, // sections 4.1.2 and 4.1.3. if c.handshakes > 255 { @@ -46,7 +53,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error { } // Consistency check on the presence of a keyShare and its parameters. - if hs.ecdheParams == nil || len(hs.hello.keyShares) != 1 { + if hs.ecdheKey == nil || len(hs.hello.keyShares) != 1 { return c.sendAlert(alertInternalError) } @@ -65,8 +72,10 @@ func (hs *clientHandshakeStateTLS13) handshake() error { return err } } + hs.transcript.Write(hs.serverHello.marshal()) + //c.buffering = true if err := hs.processServerHello(); err != nil { return err } @@ -192,6 +201,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { if hs.serverHello.cookie != nil { hs.hello.cookie = hs.serverHello.cookie } + if hs.serverHello.serverShare.group != 0 { c.sendAlert(alertDecodeError) return errors.New("tls: received malformed key_share extension") @@ -212,21 +222,21 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { c.sendAlert(alertIllegalParameter) return errors.New("tls: server selected unsupported group") } - if hs.ecdheParams.CurveID() == curveID { + if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); sentID == curveID { c.sendAlert(alertIllegalParameter) return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share") } - if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + if _, ok := curveForCurveID(curveID); !ok { c.sendAlert(alertInternalError) return errors.New("tls: CurvePreferences includes unsupported curve") } - params, err := generateECDHEParameters(c.config.rand(), curveID) + key, err := generateECDHEKey(c.config.rand(), curveID) if err != nil { c.sendAlert(alertInternalError) return err } - hs.ecdheParams = params - hs.hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} + hs.ecdheKey = key + hs.hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} } hs.hello.raw = nil @@ -300,7 +310,7 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error { c.sendAlert(alertIllegalParameter) return errors.New("tls: server did not send a key share") } - if hs.serverHello.serverShare.group != hs.ecdheParams.CurveID() { + if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); hs.serverHello.serverShare.group != sentID { c.sendAlert(alertIllegalParameter) return errors.New("tls: server selected unsupported group") } @@ -338,8 +348,13 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error { func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { c := hs.c - sharedKey := hs.ecdheParams.SharedKey(hs.serverHello.serverShare.data) - if sharedKey == nil { + peerKey, err := hs.ecdheKey.Curve().NewPublicKey(hs.serverHello.serverShare.data) + if err != nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid server key share") + } + sharedKey, err := hs.ecdheKey.ECDH(peerKey) + if err != nil { c.sendAlert(alertIllegalParameter) return errors.New("tls: invalid server key share") } @@ -358,7 +373,7 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { serverHandshakeTrafficLabel, hs.transcript) c.in.setTrafficSecret(hs.suite, serverSecret) - err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret) + err = c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret) if err != nil { c.sendAlert(alertInternalError) return err @@ -390,9 +405,9 @@ func (hs *clientHandshakeStateTLS13) readServerParameters() error { } hs.transcript.Write(encryptedExtensions.marshal()) - if len(encryptedExtensions.alpnProtocol) != 0 && len(hs.hello.alpnProtocols) == 0 { + if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil { c.sendAlert(alertUnsupportedExtension) - return errors.New("tls: server advertised unrequested ALPN extension") + return err } c.clientProtocol = encryptedExtensions.alpnProtocol @@ -464,7 +479,7 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { } // See RFC 8446, Section 4.4.3. - if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms) { + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) { c.sendAlert(alertIllegalParameter) return errors.New("tls: certificate used with invalid signature algorithm") } diff --git a/pkg/tls/handshake_messages.go b/pkg/tls/handshake_messages.go index b5f81e443..7ab0f100b 100644 --- a/pkg/tls/handshake_messages.go +++ b/pkg/tls/handshake_messages.go @@ -329,8 +329,7 @@ func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) { m.pskBinders = pskBinders if m.raw != nil { lenWithoutBinders := len(m.marshalWithoutBinders()) - // TODO(filippo): replace with NewFixedBuilder once CL 148882 is imported. - b := cryptobyte.NewBuilder(m.raw[:lenWithoutBinders]) + b := cryptobyte.NewFixedBuilder(m.raw[:lenWithoutBinders]) b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { for _, binder := range m.pskBinders { b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { @@ -338,7 +337,7 @@ func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) { }) } }) - if len(b.BytesOrPanic()) != len(m.raw) { + if out, err := b.Bytes(); err != nil || len(out) != len(m.raw) { panic("tls: internal error: failed to update binders") } } @@ -385,6 +384,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { return false } + seenExts := make(map[uint16]bool) for !extensions.Empty() { var extension uint16 var extData cryptobyte.String @@ -393,6 +393,11 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { return false } + if seenExts[extension] { + return false + } + seenExts[extension] = true + switch extension { case extensionServerName: // RFC 6066, Section 3 @@ -751,6 +756,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { return false } + seenExts := make(map[uint16]bool) for !extensions.Empty() { var extension uint16 var extData cryptobyte.String @@ -759,6 +765,11 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { return false } + if seenExts[extension] { + return false + } + seenExts[extension] = true + switch extension { case extensionStatusRequest: m.ocspStapling = true diff --git a/pkg/tls/handshake_messages_test.go b/pkg/tls/handshake_messages_test.go new file mode 100644 index 000000000..c6fc8f2bf --- /dev/null +++ b/pkg/tls/handshake_messages_test.go @@ -0,0 +1,486 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "encoding/hex" + "math/rand" + "reflect" + "strings" + "testing" + "testing/quick" + "time" +) + +var tests = []any{ + &clientHelloMsg{}, + &serverHelloMsg{}, + &finishedMsg{}, + + &certificateMsg{}, + &certificateRequestMsg{}, + &certificateVerifyMsg{ + hasSignatureAlgorithm: true, + }, + &certificateStatusMsg{}, + &clientKeyExchangeMsg{}, + &newSessionTicketMsg{}, + &sessionState{}, + &sessionStateTLS13{}, + &encryptedExtensionsMsg{}, + &endOfEarlyDataMsg{}, + &keyUpdateMsg{}, + &newSessionTicketMsgTLS13{}, + &certificateRequestMsgTLS13{}, + &certificateMsgTLS13{}, +} + +func TestMarshalUnmarshal(t *testing.T) { + rand := rand.New(rand.NewSource(time.Now().UnixNano())) + + for i, iface := range tests { + ty := reflect.ValueOf(iface).Type() + + n := 100 + if testing.Short() { + n = 5 + } + for j := 0; j < n; j++ { + v, ok := quick.Value(ty, rand) + if !ok { + t.Errorf("#%d: failed to create value", i) + break + } + + m1 := v.Interface().(handshakeMessage) + marshaled := m1.marshal() + m2 := iface.(handshakeMessage) + if !m2.unmarshal(marshaled) { + t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) + break + } + m2.marshal() // to fill any marshal cache in the message + + if !reflect.DeepEqual(m1, m2) { + t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled) + break + } + + if i >= 3 { + // The first three message types (ClientHello, + // ServerHello and Finished) are allowed to + // have parsable prefixes because the extension + // data is optional and the length of the + // Finished varies across versions. + for j := 0; j < len(marshaled); j++ { + if m2.unmarshal(marshaled[0:j]) { + t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1) + break + } + } + } + } + } +} + +func TestFuzz(t *testing.T) { + rand := rand.New(rand.NewSource(0)) + for _, iface := range tests { + m := iface.(handshakeMessage) + + for j := 0; j < 1000; j++ { + len := rand.Intn(100) + bytes := randomBytes(len, rand) + // This just looks for crashes due to bounds errors etc. + m.unmarshal(bytes) + } + } +} + +func randomBytes(n int, rand *rand.Rand) []byte { + r := make([]byte, n) + if _, err := rand.Read(r); err != nil { + panic("rand.Read failed: " + err.Error()) + } + return r +} + +func randomString(n int, rand *rand.Rand) string { + b := randomBytes(n, rand) + return string(b) +} + +func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &clientHelloMsg{} + m.vers = uint16(rand.Intn(65536)) + m.random = randomBytes(32, rand) + m.sessionId = randomBytes(rand.Intn(32), rand) + m.cipherSuites = make([]uint16, rand.Intn(63)+1) + for i := 0; i < len(m.cipherSuites); i++ { + cs := uint16(rand.Int31()) + if cs == scsvRenegotiation { + cs += 1 + } + m.cipherSuites[i] = cs + } + m.compressionMethods = randomBytes(rand.Intn(63)+1, rand) + if rand.Intn(10) > 5 { + m.serverName = randomString(rand.Intn(255), rand) + for strings.HasSuffix(m.serverName, ".") { + m.serverName = m.serverName[:len(m.serverName)-1] + } + } + m.ocspStapling = rand.Intn(10) > 5 + m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) + m.supportedCurves = make([]CurveID, rand.Intn(5)+1) + for i := range m.supportedCurves { + m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1) + } + if rand.Intn(10) > 5 { + m.ticketSupported = true + if rand.Intn(10) > 5 { + m.sessionTicket = randomBytes(rand.Intn(300), rand) + } else { + m.sessionTicket = make([]byte, 0) + } + } + if rand.Intn(10) > 5 { + m.supportedSignatureAlgorithms = supportedSignatureAlgorithms() + } + if rand.Intn(10) > 5 { + m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms() + } + for i := 0; i < rand.Intn(5); i++ { + m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand)) + } + if rand.Intn(10) > 5 { + m.scts = true + } + if rand.Intn(10) > 5 { + m.secureRenegotiationSupported = true + m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) + } + for i := 0; i < rand.Intn(5); i++ { + m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1)) + } + if rand.Intn(10) > 5 { + m.cookie = randomBytes(rand.Intn(500)+1, rand) + } + for i := 0; i < rand.Intn(5); i++ { + var ks keyShare + ks.group = CurveID(rand.Intn(30000) + 1) + ks.data = randomBytes(rand.Intn(200)+1, rand) + m.keyShares = append(m.keyShares, ks) + } + switch rand.Intn(3) { + case 1: + m.pskModes = []uint8{pskModeDHE} + case 2: + m.pskModes = []uint8{pskModeDHE, pskModePlain} + } + for i := 0; i < rand.Intn(5); i++ { + var psk pskIdentity + psk.obfuscatedTicketAge = uint32(rand.Intn(500000)) + psk.label = randomBytes(rand.Intn(500)+1, rand) + m.pskIdentities = append(m.pskIdentities, psk) + m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand)) + } + if rand.Intn(10) > 5 { + m.earlyData = true + } + + return reflect.ValueOf(m) +} + +func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &serverHelloMsg{} + m.vers = uint16(rand.Intn(65536)) + m.random = randomBytes(32, rand) + m.sessionId = randomBytes(rand.Intn(32), rand) + m.cipherSuite = uint16(rand.Int31()) + m.compressionMethod = uint8(rand.Intn(256)) + m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) + + if rand.Intn(10) > 5 { + m.ocspStapling = true + } + if rand.Intn(10) > 5 { + m.ticketSupported = true + } + if rand.Intn(10) > 5 { + m.alpnProtocol = randomString(rand.Intn(32)+1, rand) + } + + for i := 0; i < rand.Intn(4); i++ { + m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand)) + } + + if rand.Intn(10) > 5 { + m.secureRenegotiationSupported = true + m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) + } + if rand.Intn(10) > 5 { + m.supportedVersion = uint16(rand.Intn(0xffff) + 1) + } + if rand.Intn(10) > 5 { + m.cookie = randomBytes(rand.Intn(500)+1, rand) + } + if rand.Intn(10) > 5 { + for i := 0; i < rand.Intn(5); i++ { + m.serverShare.group = CurveID(rand.Intn(30000) + 1) + m.serverShare.data = randomBytes(rand.Intn(200)+1, rand) + } + } else if rand.Intn(10) > 5 { + m.selectedGroup = CurveID(rand.Intn(30000) + 1) + } + if rand.Intn(10) > 5 { + m.selectedIdentityPresent = true + m.selectedIdentity = uint16(rand.Intn(0xffff)) + } + + return reflect.ValueOf(m) +} + +func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &encryptedExtensionsMsg{} + + if rand.Intn(10) > 5 { + m.alpnProtocol = randomString(rand.Intn(32)+1, rand) + } + + return reflect.ValueOf(m) +} + +func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &certificateMsg{} + numCerts := rand.Intn(20) + m.certificates = make([][]byte, numCerts) + for i := 0; i < numCerts; i++ { + m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) + } + return reflect.ValueOf(m) +} + +func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &certificateRequestMsg{} + m.certificateTypes = randomBytes(rand.Intn(5)+1, rand) + for i := 0; i < rand.Intn(100); i++ { + m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand)) + } + return reflect.ValueOf(m) +} + +func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &certificateVerifyMsg{} + m.hasSignatureAlgorithm = true + m.signatureAlgorithm = SignatureScheme(rand.Intn(30000)) + m.signature = randomBytes(rand.Intn(15)+1, rand) + return reflect.ValueOf(m) +} + +func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &certificateStatusMsg{} + m.response = randomBytes(rand.Intn(10)+1, rand) + return reflect.ValueOf(m) +} + +func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &clientKeyExchangeMsg{} + m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) + return reflect.ValueOf(m) +} + +func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &finishedMsg{} + m.verifyData = randomBytes(12, rand) + return reflect.ValueOf(m) +} + +func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &newSessionTicketMsg{} + m.ticket = randomBytes(rand.Intn(4), rand) + return reflect.ValueOf(m) +} + +func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value { + s := &sessionState{} + s.vers = uint16(rand.Intn(10000)) + s.cipherSuite = uint16(rand.Intn(10000)) + s.masterSecret = randomBytes(rand.Intn(100)+1, rand) + s.createdAt = uint64(rand.Int63()) + for i := 0; i < rand.Intn(20); i++ { + s.certificates = append(s.certificates, randomBytes(rand.Intn(500)+1, rand)) + } + return reflect.ValueOf(s) +} + +func (*sessionStateTLS13) Generate(rand *rand.Rand, size int) reflect.Value { + s := &sessionStateTLS13{} + s.cipherSuite = uint16(rand.Intn(10000)) + s.resumptionSecret = randomBytes(rand.Intn(100)+1, rand) + s.createdAt = uint64(rand.Int63()) + for i := 0; i < rand.Intn(2)+1; i++ { + s.certificate.Certificate = append( + s.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) + } + if rand.Intn(10) > 5 { + s.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) + } + if rand.Intn(10) > 5 { + for i := 0; i < rand.Intn(2)+1; i++ { + s.certificate.SignedCertificateTimestamps = append( + s.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) + } + } + return reflect.ValueOf(s) +} + +func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &endOfEarlyDataMsg{} + return reflect.ValueOf(m) +} + +func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &keyUpdateMsg{} + m.updateRequested = rand.Intn(10) > 5 + return reflect.ValueOf(m) +} + +func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { + m := &newSessionTicketMsgTLS13{} + m.lifetime = uint32(rand.Intn(500000)) + m.ageAdd = uint32(rand.Intn(500000)) + m.nonce = randomBytes(rand.Intn(100), rand) + m.label = randomBytes(rand.Intn(1000), rand) + if rand.Intn(10) > 5 { + m.maxEarlyData = uint32(rand.Intn(500000)) + } + return reflect.ValueOf(m) +} + +func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { + m := &certificateRequestMsgTLS13{} + if rand.Intn(10) > 5 { + m.ocspStapling = true + } + if rand.Intn(10) > 5 { + m.scts = true + } + if rand.Intn(10) > 5 { + m.supportedSignatureAlgorithms = supportedSignatureAlgorithms() + } + if rand.Intn(10) > 5 { + m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms() + } + if rand.Intn(10) > 5 { + m.certificateAuthorities = make([][]byte, 3) + for i := 0; i < 3; i++ { + m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand) + } + } + return reflect.ValueOf(m) +} + +func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { + m := &certificateMsgTLS13{} + for i := 0; i < rand.Intn(2)+1; i++ { + m.certificate.Certificate = append( + m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) + } + if rand.Intn(10) > 5 { + m.ocspStapling = true + m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) + } + if rand.Intn(10) > 5 { + m.scts = true + for i := 0; i < rand.Intn(2)+1; i++ { + m.certificate.SignedCertificateTimestamps = append( + m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) + } + } + return reflect.ValueOf(m) +} + +func TestRejectEmptySCTList(t *testing.T) { + // RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid. + + var random [32]byte + sct := []byte{0x42, 0x42, 0x42, 0x42} + serverHello := serverHelloMsg{ + vers: VersionTLS12, + random: random[:], + scts: [][]byte{sct}, + } + serverHelloBytes := serverHello.marshal() + + var serverHelloCopy serverHelloMsg + if !serverHelloCopy.unmarshal(serverHelloBytes) { + t.Fatal("Failed to unmarshal initial message") + } + + // Change serverHelloBytes so that the SCT list is empty + i := bytes.Index(serverHelloBytes, sct) + if i < 0 { + t.Fatal("Cannot find SCT in ServerHello") + } + + var serverHelloEmptySCT []byte + serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...) + // Append the extension length and SCT list length for an empty list. + serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...) + serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...) + + // Update the handshake message length. + serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16) + serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8) + serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4) + + // Update the extensions length + serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8) + serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44)) + + if serverHelloCopy.unmarshal(serverHelloEmptySCT) { + t.Fatal("Unmarshaled ServerHello with empty SCT list") + } +} + +func TestRejectEmptySCT(t *testing.T) { + // Not only must the SCT list be non-empty, but the SCT elements must + // not be zero length. + + var random [32]byte + serverHello := serverHelloMsg{ + vers: VersionTLS12, + random: random[:], + scts: [][]byte{nil}, + } + serverHelloBytes := serverHello.marshal() + + var serverHelloCopy serverHelloMsg + if serverHelloCopy.unmarshal(serverHelloBytes) { + t.Fatal("Unmarshaled ServerHello with zero-length SCT") + } +} + +func TestRejectDuplicateExtensions(t *testing.T) { + clientHelloBytes, err := hex.DecodeString("010000440303000000000000000000000000000000000000000000000000000000000000000000000000001c0000000a000800000568656c6c6f0000000a000800000568656c6c6f") + if err != nil { + t.Fatalf("failed to decode test ClientHello: %s", err) + } + var clientHelloCopy clientHelloMsg + if clientHelloCopy.unmarshal(clientHelloBytes) { + t.Error("Unmarshaled ClientHello with duplicate extensions") + } + + serverHelloBytes, err := hex.DecodeString("02000030030300000000000000000000000000000000000000000000000000000000000000000000000000080005000000050000") + if err != nil { + t.Fatalf("failed to decode test ServerHello: %s", err) + } + var serverHelloCopy serverHelloMsg + if serverHelloCopy.unmarshal(serverHelloBytes) { + t.Fatal("Unmarshaled ServerHello with duplicate extensions") + } +} diff --git a/pkg/tls/handshake_server.go b/pkg/tls/handshake_server.go index ede1c52e0..f4506ea63 100644 --- a/pkg/tls/handshake_server.go +++ b/pkg/tls/handshake_server.go @@ -5,6 +5,7 @@ package tls import ( + "context" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -13,6 +14,7 @@ import ( "crypto/x509" "errors" "fmt" + "hash" "io" "strconv" "time" @@ -22,6 +24,7 @@ import ( // It's discarded once the handshake has completed. type serverHandshakeState struct { c *Conn + ctx context.Context clientHello *clientHelloMsg hello *serverHelloMsg suite *cipherSuite @@ -39,25 +42,28 @@ type serverHandshakeState struct { } // serverHandshake performs a TLS handshake as a server. -func (c *Conn) serverHandshake() error { +func (c *Conn) serverHandshake(ctx context.Context) error { // If this is the first server handshake, we generate a random key to // encrypt the tickets with. //gnet不能进行阻塞二次读取,所以会分几条消息重复执行此方法,status也会分很多个状态 if c.hs == nil { //首次执行要初始化对象 - clientHello, err := c.readClientHello() + clientHello, err := c.readClientHello(ctx) if err != nil { return err } + if c.vers == VersionTLS13 { c.hs = &serverHandshakeStateTLS13{ c: c, + ctx: ctx, clientHello: clientHello, } } else { c.hs = &serverHandshakeState{ c: c, + ctx: ctx, clientHello: clientHello, } } @@ -142,7 +148,6 @@ func (hs *serverHandshakeState) handshake() error { return err } if err := hs.sendFinished(nil); err != nil { - return err } if _, err := c.flush(); err != nil { @@ -154,11 +159,12 @@ func (hs *serverHandshakeState) handshake() error { c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random) c.handshakeStatus = 255 + return nil } // readClientHello reads a ClientHello message and selects the protocol version. -func (c *Conn) readClientHello() (*clientHelloMsg, error) { +func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) { msg, err := c.readHandshake() if err != nil { return nil, err @@ -172,7 +178,7 @@ func (c *Conn) readClientHello() (*clientHelloMsg, error) { var configForClient *Config originalConfig := c.config if c.config.GetConfigForClient != nil { - chi := clientHelloInfo(c, clientHello) + chi := clientHelloInfo(ctx, c, clientHello) if configForClient, err = c.config.GetConfigForClient(chi); err != nil { c.sendAlert(alertInternalError) return nil, err @@ -186,7 +192,7 @@ func (c *Conn) readClientHello() (*clientHelloMsg, error) { if len(clientHello.supportedVersions) == 0 { clientVersions = supportedVersionsFromMax(clientHello.vers) } - c.vers, ok = c.config.mutualVersion(clientVersions) + c.vers, ok = c.config.mutualVersion(roleServer, clientVersions) if !ok { c.sendAlert(alertProtocolVersion) return nil, fmt.Errorf("tls: client offered only unsupported versions: %x", clientVersions) @@ -221,7 +227,7 @@ func (hs *serverHandshakeState) processClientHello() error { hs.hello.random = make([]byte, 32) serverRandom := hs.hello.random // Downgrade protection canaries. See RFC 8446, Section 4.1.3. - maxVers := c.config.maxSupportedVersion() + maxVers := c.config.maxSupportedVersion(roleServer) if maxVers >= VersionTLS12 && c.vers < maxVers || testingOnlyForceDowngradeCanary { if c.vers == VersionTLS12 { copy(serverRandom[24:], downgradeCanaryTLS12) @@ -247,14 +253,15 @@ func (hs *serverHandshakeState) processClientHello() error { c.serverName = hs.clientHello.serverName } - if len(hs.clientHello.alpnProtocols) > 0 { - if selectedProto, fallback := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); !fallback { - hs.hello.alpnProtocol = selectedProto - c.clientProtocol = selectedProto - } + selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols) + if err != nil { + c.sendAlert(alertNoApplicationProtocol) + return err } + hs.hello.alpnProtocol = selectedProto + c.clientProtocol = selectedProto - hs.cert, err = c.config.getCertificate(clientHelloInfo(c, hs.clientHello)) + hs.cert, err = c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello)) if err != nil { if err == errNoCertificates { c.sendAlert(alertUnrecognizedName) @@ -269,7 +276,7 @@ func (hs *serverHandshakeState) processClientHello() error { hs.ecdheOk = supportsECDHE(c.config, hs.clientHello.supportedCurves, hs.clientHello.supportedPoints) - if hs.ecdheOk { + if hs.ecdheOk && len(hs.clientHello.supportedPoints) > 0 { // Although omitting the ec_point_formats extension is permitted, some // old OpenSSL version will refuse to handshake if not present. // @@ -304,6 +311,34 @@ func (hs *serverHandshakeState) processClientHello() error { return nil } +// negotiateALPN picks a shared ALPN protocol that both sides support in server +// preference order. If ALPN is not configured or the peer doesn't support it, +// it returns "" and no error. +func negotiateALPN(serverProtos, clientProtos []string) (string, error) { + if len(serverProtos) == 0 || len(clientProtos) == 0 { + return "", nil + } + var http11fallback bool + for _, s := range serverProtos { + for _, c := range clientProtos { + if s == c { + return s, nil + } + if s == "h2" && c == "http/1.1" { + http11fallback = true + } + } + } + // As a special case, let http/1.1 clients connect to h2 servers as if they + // didn't support ALPN. We used not to enforce protocol overlap, so over + // time a number of HTTP servers were configured with only "h2", but + // expected to accept connections from "http/1.1" clients. See Issue 46310. + if http11fallback { + return "", nil + } + return "", fmt.Errorf("tls: client requested unsupported application protocols (%s)", clientProtos) +} + // supportsECDHE returns whether ECDHE key exchanges can be used with this // pre-TLS 1.3 client. func supportsECDHE(c *Config, supportedCurves []CurveID, supportedPoints []uint8) bool { @@ -322,6 +357,13 @@ func supportsECDHE(c *Config, supportedCurves []CurveID, supportedPoints []uint8 break } } + // Per RFC 8422, Section 5.1.2, if the Supported Point Formats extension is + // missing, uncompressed points are supported. If supportedPoints is empty, + // the extension must be missing, as an empty extension body is rejected by + // the parser. See https://go.dev/issue/49126. + if len(supportedPoints) == 0 { + supportsPointFormat = true + } return supportsCurve && supportsPointFormat } @@ -329,16 +371,23 @@ func supportsECDHE(c *Config, supportedCurves []CurveID, supportedPoints []uint8 func (hs *serverHandshakeState) pickCipherSuite() error { c := hs.c - var preferenceList, supportedList []uint16 - if c.config.PreferServerCipherSuites { - preferenceList = c.config.cipherSuites() - supportedList = hs.clientHello.cipherSuites - } else { - preferenceList = hs.clientHello.cipherSuites - supportedList = c.config.cipherSuites() + preferenceOrder := cipherSuitesPreferenceOrder + if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) { + preferenceOrder = cipherSuitesPreferenceOrderNoAES } - hs.suite = selectCipherSuite(preferenceList, supportedList, hs.cipherSuiteOk) + configCipherSuites := c.config.cipherSuites() + preferenceList := make([]uint16, 0, len(configCipherSuites)) + for _, suiteID := range preferenceOrder { + for _, id := range configCipherSuites { + if id == suiteID { + preferenceList = append(preferenceList, id) + break + } + } + } + + hs.suite = selectCipherSuite(preferenceList, hs.clientHello.cipherSuites, hs.cipherSuiteOk) if hs.suite == nil { c.sendAlert(alertHandshakeFailure) return errors.New("tls: no cipher suite supported by both client and server") @@ -348,7 +397,7 @@ func (hs *serverHandshakeState) pickCipherSuite() error { for _, id := range hs.clientHello.cipherSuites { if id == TLS_FALLBACK_SCSV { // The client is doing a fallback connection. See RFC 7507. - if hs.clientHello.vers < c.config.maxSupportedVersion() { + if hs.clientHello.vers < c.config.maxSupportedVersion(roleServer) { c.sendAlert(alertInappropriateFallback) return errors.New("tls: client using inappropriate protocol fallback") } @@ -534,7 +583,7 @@ func (hs *serverHandshakeState) doFullHandshakeStep1() error { } if c.vers >= VersionTLS12 { hs.certReq.hasSignatureAlgorithm = true - hs.certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms + hs.certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms() } // An empty list of certificateAuthorities signals to @@ -565,6 +614,7 @@ func (hs *serverHandshakeState) doFullHandshakeStep2() error { c := hs.c var pub crypto.PublicKey // public key for client auth, if any + msg, err := c.readHandshake() if err != nil { return err @@ -656,7 +706,7 @@ func (hs *serverHandshakeState) doFullHandshakeStep2() error { } } - signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret) + signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash) if err := verifyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil { c.sendAlert(alertDecryptError) return errors.New("tls: invalid signature by the client certificate: " + err.Error()) @@ -677,13 +727,13 @@ func (hs *serverHandshakeState) establishKeys() error { keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) var clientCipher, serverCipher interface{} - var clientHash, serverHash macFunction + var clientHash, serverHash hash.Hash if hs.suite.aead == nil { clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */) - clientHash = hs.suite.mac(c.vers, clientMAC) + clientHash = hs.suite.mac(clientMAC) serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */) - serverHash = hs.suite.mac(c.vers, serverMAC) + serverHash = hs.suite.mac(serverMAC) } else { clientCipher = hs.suite.aead(clientKey, clientIV) serverCipher = hs.suite.aead(serverKey, serverIV) @@ -820,7 +870,7 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error { chains, err := certs[0].Verify(opts) if err != nil { c.sendAlert(alertBadCertificate) - return errors.New("tls: failed to verify client certificate: " + err.Error()) + return &CertificateVerificationError{UnverifiedCertificates: certs, Err: err} } c.verifiedChains = chains @@ -849,7 +899,7 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error { return nil } -func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { +func clientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { supportedVersions := clientHello.supportedVersions if len(clientHello.supportedVersions) == 0 { supportedVersions = supportedVersionsFromMax(clientHello.vers) diff --git a/pkg/tls/handshake_server_tls13.go b/pkg/tls/handshake_server_tls13.go index 467612649..ee527036e 100644 --- a/pkg/tls/handshake_server_tls13.go +++ b/pkg/tls/handshake_server_tls13.go @@ -6,9 +6,11 @@ package tls import ( "bytes" + "context" "crypto" "crypto/hmac" "crypto/rsa" + "encoding/binary" "errors" "hash" "io" @@ -22,6 +24,7 @@ const maxClientPSKIdentities = 5 type serverHandshakeStateTLS13 struct { c *Conn + ctx context.Context clientHello *clientHelloMsg hello *serverHelloMsg sentDummyCCS bool @@ -40,6 +43,11 @@ type serverHandshakeStateTLS13 struct { func (hs *serverHandshakeStateTLS13) handshake() error { c := hs.c + + if needFIPS() { + return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode") + } + switch c.handshakeStatus { case 0: // For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2. @@ -48,31 +56,25 @@ func (hs *serverHandshakeStateTLS13) handshake() error { return err } if err := hs.checkForResumption(); err != nil { - return err } if err := hs.pickCertificate(); err != nil { - return err } //c.buffering = true if err := hs.sendServerParameters(); err != nil { - return err } if err := hs.sendServerCertificate(); err != nil { - return err } if err := hs.sendServerFinished(); err != nil { - return err } // Note that at this point we could start sending application data without // waiting for the client's second flight, but the application might not // expect the lack of replay protection of the ClientHello parameters. if _, err := c.flush(); err != nil { - return err } c.handshakeStatus = 1 @@ -81,7 +83,6 @@ func (hs *serverHandshakeStateTLS13) handshake() error { return err } if err := hs.readClientFinished(); err != nil { - return err } c.handshakeStatus = 255 @@ -118,7 +119,7 @@ func (hs *serverHandshakeStateTLS13) processClientHello() error { if id == TLS_FALLBACK_SCSV { // Use c.vers instead of max(supported_versions) because an attacker // could defeat this by adding an arbitrary high version otherwise. - if c.vers < c.config.maxSupportedVersion() { + if c.vers < c.config.maxSupportedVersion(roleServer) { c.sendAlert(alertInappropriateFallback) return errors.New("tls: client using inappropriate protocol fallback") } @@ -157,16 +158,12 @@ func (hs *serverHandshakeStateTLS13) processClientHello() error { hs.hello.sessionId = hs.clientHello.sessionId hs.hello.compressionMethod = compressionNone - var preferenceList, supportedList []uint16 - if c.config.PreferServerCipherSuites { - preferenceList = defaultCipherSuitesTLS13() - supportedList = hs.clientHello.cipherSuites - } else { - preferenceList = hs.clientHello.cipherSuites - supportedList = defaultCipherSuitesTLS13() + preferenceList := defaultCipherSuitesTLS13 + if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) { + preferenceList = defaultCipherSuitesTLS13NoAES } for _, suiteID := range preferenceList { - hs.suite = mutualCipherSuiteTLS13(supportedList, suiteID) + hs.suite = mutualCipherSuiteTLS13(hs.clientHello.cipherSuites, suiteID) if hs.suite != nil { break } @@ -213,18 +210,23 @@ GroupSelection: clientKeyShare = &hs.clientHello.keyShares[0] } - if _, ok := curveForCurveID(selectedGroup); selectedGroup != X25519 && !ok { + if _, ok := curveForCurveID(selectedGroup); !ok { c.sendAlert(alertInternalError) return errors.New("tls: CurvePreferences includes unsupported curve") } - params, err := generateECDHEParameters(c.config.rand(), selectedGroup) + key, err := generateECDHEKey(c.config.rand(), selectedGroup) if err != nil { c.sendAlert(alertInternalError) return err } - hs.hello.serverShare = keyShare{group: selectedGroup, data: params.PublicKey()} - hs.sharedKey = params.SharedKey(clientKeyShare.data) - if hs.sharedKey == nil { + hs.hello.serverShare = keyShare{group: selectedGroup, data: key.PublicKey().Bytes()} + peerKey, err := key.Curve().NewPublicKey(clientKeyShare.data) + if err != nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid client key share") + } + hs.sharedKey, err = key.ECDH(peerKey) + if err != nil { c.sendAlert(alertIllegalParameter) return errors.New("tls: invalid client key share") } @@ -371,7 +373,7 @@ func (hs *serverHandshakeStateTLS13) pickCertificate() error { return c.sendAlert(alertMissingExtension) } - certificate, err := c.config.getCertificate(clientHelloInfo(c, hs.clientHello)) + certificate, err := c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello)) if err != nil { if err == errNoCertificates { c.sendAlert(alertUnrecognizedName) @@ -562,12 +564,13 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error { encryptedExtensions := new(encryptedExtensionsMsg) - if len(hs.clientHello.alpnProtocols) > 0 { - if selectedProto, fallback := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); !fallback { - encryptedExtensions.alpnProtocol = selectedProto - c.clientProtocol = selectedProto - } + selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols) + if err != nil { + c.sendAlert(alertNoApplicationProtocol) + return err } + encryptedExtensions.alpnProtocol = selectedProto + c.clientProtocol = selectedProto hs.transcript.Write(encryptedExtensions.marshal()) if _, err := c.writeRecord(recordTypeHandshake, encryptedExtensions.marshal()); err != nil { @@ -594,7 +597,7 @@ func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { certReq := new(certificateRequestMsgTLS13) certReq.ocspStapling = true certReq.scts = true - certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms + certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms() if c.config.ClientCAs != nil { certReq.certificateAuthorities = c.config.ClientCAs.Subjects() } @@ -752,6 +755,19 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { } m.lifetime = uint32(maxSessionTicketLifetime / time.Second) + // ticket_age_add is a random 32-bit value. See RFC 8446, section 4.6.1 + // The value is not stored anywhere; we never need to check the ticket age + // because 0-RTT is not supported. + ageAdd := make([]byte, 4) + _, err = hs.c.config.rand().Read(ageAdd) + if err != nil { + return err + } + m.ageAdd = binary.LittleEndian.Uint32(ageAdd) + + // ticket_nonce, which must be unique per connection, is always left at + // zero because we only ever send one ticket per connection. + if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { return err } @@ -813,7 +829,7 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { } // See RFC 8446, Section 4.4.3. - if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms) { + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) { c.sendAlert(alertIllegalParameter) return errors.New("tls: client certificate used with invalid signature algorithm") } diff --git a/pkg/tls/handshake_test.go b/pkg/tls/handshake_test.go new file mode 100644 index 000000000..bacc8b7d4 --- /dev/null +++ b/pkg/tls/handshake_test.go @@ -0,0 +1,530 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bufio" + "crypto/ed25519" + "crypto/x509" + "encoding/hex" + "errors" + "flag" + "fmt" + "io" + "net" + "os" + "os/exec" + "runtime" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +// TLS reference tests run a connection against a reference implementation +// (OpenSSL) of TLS and record the bytes of the resulting connection. The Go +// code, during a test, is configured with deterministic randomness and so the +// reference test can be reproduced exactly in the future. +// +// In order to save everyone who wishes to run the tests from needing the +// reference implementation installed, the reference connections are saved in +// files in the testdata directory. Thus running the tests involves nothing +// external, but creating and updating them requires the reference +// implementation. +// +// Tests can be updated by running them with the -update flag. This will cause +// the test files for failing tests to be regenerated. Since the reference +// implementation will always generate fresh random numbers, large parts of the +// reference connection will always change. + +var ( + update = flag.Bool("update", false, "update golden files on failure") + fast = flag.Bool("fast", false, "impose a quick, possibly flaky timeout on recorded tests") + keyFile = flag.String("keylog", "", "destination file for KeyLogWriter") +) + +func runTestAndUpdateIfNeeded(t *testing.T, name string, run func(t *testing.T, update bool), wait bool) { + success := t.Run(name, func(t *testing.T) { + if !*update && !wait { + t.Parallel() + } + run(t, false) + }) + + if !success && *update { + t.Run(name+"#update", func(t *testing.T) { + run(t, true) + }) + } +} + +// checkOpenSSLVersion ensures that the version of OpenSSL looks reasonable +// before updating the test data. +func checkOpenSSLVersion() error { + if !*update { + return nil + } + + openssl := exec.Command("openssl", "version") + output, err := openssl.CombinedOutput() + if err != nil { + return err + } + + version := string(output) + if strings.HasPrefix(version, "OpenSSL 1.1.1") { + return nil + } + + println("***********************************************") + println("") + println("You need to build OpenSSL 1.1.1 from source in order") + println("to update the test data.") + println("") + println("Configure it with:") + println("./Configure enable-weak-ssl-ciphers no-shared") + println("and then add the apps/ directory at the front of your PATH.") + println("***********************************************") + + return errors.New("version of OpenSSL does not appear to be suitable for updating test data") +} + +// recordingConn is a net.Conn that records the traffic that passes through it. +// WriteTo can be used to produce output that can be later be loaded with +// ParseTestData. +type recordingConn struct { + net.Conn + sync.Mutex + flows [][]byte + reading bool +} + +func (r *recordingConn) Read(b []byte) (n int, err error) { + if n, err = r.Conn.Read(b); n == 0 { + return + } + b = b[:n] + + r.Lock() + defer r.Unlock() + + if l := len(r.flows); l == 0 || !r.reading { + buf := make([]byte, len(b)) + copy(buf, b) + r.flows = append(r.flows, buf) + } else { + r.flows[l-1] = append(r.flows[l-1], b[:n]...) + } + r.reading = true + return +} + +func (r *recordingConn) Write(b []byte) (n int, err error) { + if n, err = r.Conn.Write(b); n == 0 { + return + } + b = b[:n] + + r.Lock() + defer r.Unlock() + + if l := len(r.flows); l == 0 || r.reading { + buf := make([]byte, len(b)) + copy(buf, b) + r.flows = append(r.flows, buf) + } else { + r.flows[l-1] = append(r.flows[l-1], b[:n]...) + } + r.reading = false + return +} + +// WriteTo writes Go source code to w that contains the recorded traffic. +func (r *recordingConn) WriteTo(w io.Writer) (int64, error) { + // TLS always starts with a client to server flow. + clientToServer := true + var written int64 + for i, flow := range r.flows { + source, dest := "client", "server" + if !clientToServer { + source, dest = dest, source + } + n, err := fmt.Fprintf(w, ">>> Flow %d (%s to %s)\n", i+1, source, dest) + written += int64(n) + if err != nil { + return written, err + } + dumper := hex.Dumper(w) + n, err = dumper.Write(flow) + written += int64(n) + if err != nil { + return written, err + } + err = dumper.Close() + if err != nil { + return written, err + } + clientToServer = !clientToServer + } + return written, nil +} + +func parseTestData(r io.Reader) (flows [][]byte, err error) { + var currentFlow []byte + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + // If the line starts with ">>> " then it marks the beginning + // of a new flow. + if strings.HasPrefix(line, ">>> ") { + if len(currentFlow) > 0 || len(flows) > 0 { + flows = append(flows, currentFlow) + currentFlow = nil + } + continue + } + + // Otherwise the line is a line of hex dump that looks like: + // 00000170 fc f5 06 bf (...) |.....X{&?......!| + // (Some bytes have been omitted from the middle section.) + _, after, ok := strings.Cut(line, " ") + if !ok { + return nil, errors.New("invalid test data") + } + line = after + + before, _, ok := strings.Cut(line, "|") + if !ok { + return nil, errors.New("invalid test data") + } + line = before + + hexBytes := strings.Fields(line) + for _, hexByte := range hexBytes { + val, err := strconv.ParseUint(hexByte, 16, 8) + if err != nil { + return nil, errors.New("invalid hex byte in test data: " + err.Error()) + } + currentFlow = append(currentFlow, byte(val)) + } + } + + if len(currentFlow) > 0 { + flows = append(flows, currentFlow) + } + + return flows, nil +} + +// tempFile creates a temp file containing contents and returns its path. +func tempFile(contents string) string { + file, err := os.CreateTemp("", "go-tls-test") + if err != nil { + panic("failed to create temp file: " + err.Error()) + } + path := file.Name() + file.WriteString(contents) + file.Close() + return path +} + +// localListener is set up by TestMain and used by localPipe to create Conn +// pairs like net.Pipe, but connected by an actual buffered TCP connection. +var localListener struct { + mu sync.Mutex + addr net.Addr + ch chan net.Conn +} + +const localFlakes = 0 // change to 1 or 2 to exercise localServer/localPipe handling of mismatches + +func localServer(l net.Listener) { + for n := 0; ; n++ { + c, err := l.Accept() + if err != nil { + return + } + if localFlakes == 1 && n%2 == 0 { + c.Close() + continue + } + localListener.ch <- c + } +} + +var isConnRefused = func(err error) bool { return false } + +func localPipe(t testing.TB) (net.Conn, net.Conn) { + localListener.mu.Lock() + defer localListener.mu.Unlock() + + addr := localListener.addr + + var err error +Dialing: + // We expect a rare mismatch, but probably not 5 in a row. + for i := 0; i < 5; i++ { + tooSlow := time.NewTimer(1 * time.Second) + defer tooSlow.Stop() + var c1 net.Conn + c1, err = net.Dial(addr.Network(), addr.String()) + if err != nil { + if runtime.GOOS == "dragonfly" && (isConnRefused(err) || os.IsTimeout(err)) { + // golang.org/issue/29583: Dragonfly sometimes returns a spurious + // ECONNREFUSED or ETIMEDOUT. + <-tooSlow.C + continue + } + t.Fatalf("localPipe: %v", err) + } + if localFlakes == 2 && i == 0 { + c1.Close() + continue + } + for { + select { + case <-tooSlow.C: + t.Logf("localPipe: timeout waiting for %v", c1.LocalAddr()) + c1.Close() + continue Dialing + + case c2 := <-localListener.ch: + if c2.RemoteAddr().String() == c1.LocalAddr().String() { + return c1, c2 + } + t.Logf("localPipe: unexpected connection: %v != %v", c2.RemoteAddr(), c1.LocalAddr()) + c2.Close() + } + } + } + + t.Fatalf("localPipe: failed to connect: %v", err) + panic("unreachable") +} + +// zeroSource is an io.Reader that returns an unlimited number of zero bytes. +type zeroSource struct{} + +func (zeroSource) Read(b []byte) (n int, err error) { + for i := range b { + b[i] = 0 + } + + return len(b), nil +} + +func allCipherSuites() []uint16 { + ids := make([]uint16, len(cipherSuites)) + for i, suite := range cipherSuites { + ids[i] = suite.id + } + + return ids +} + +var testConfig *Config + +func TestMain(m *testing.M) { + flag.Parse() + os.Exit(runMain(m)) +} + +func runMain(m *testing.M) int { + // Cipher suites preferences change based on the architecture. Force them to + // the version without AES acceleration for test consistency. + hasAESGCMHardwareSupport = false + + // Set up localPipe. + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + l, err = net.Listen("tcp6", "[::1]:0") + } + if err != nil { + fmt.Fprintf(os.Stderr, "Failed to open local listener: %v", err) + os.Exit(1) + } + localListener.ch = make(chan net.Conn) + localListener.addr = l.Addr() + defer l.Close() + go localServer(l) + + if err := checkOpenSSLVersion(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v", err) + os.Exit(1) + } + + testConfig = &Config{ + Time: func() time.Time { return time.Unix(0, 0) }, + Rand: zeroSource{}, + Certificates: make([]Certificate, 2), + InsecureSkipVerify: true, + CipherSuites: allCipherSuites(), + MinVersion: VersionTLS10, + MaxVersion: VersionTLS13, + } + testConfig.Certificates[0].Certificate = [][]byte{testRSACertificate} + testConfig.Certificates[0].PrivateKey = testRSAPrivateKey + testConfig.Certificates[1].Certificate = [][]byte{testSNICertificate} + testConfig.Certificates[1].PrivateKey = testRSAPrivateKey + testConfig.BuildNameToCertificate() + if *keyFile != "" { + f, err := os.OpenFile(*keyFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + panic("failed to open -keylog file: " + err.Error()) + } + testConfig.KeyLogWriter = f + defer f.Close() + } + + return m.Run() +} + +func testHandshake(t *testing.T, clientConfig, serverConfig *Config) (serverState, clientState ConnectionState, err error) { + const sentinel = "SENTINEL\n" + c, s := localPipe(t) + errChan := make(chan error) + go func() { + cli := Client(c, clientConfig) + err := cli.Handshake() + if err != nil { + errChan <- fmt.Errorf("client: %v", err) + c.Close() + return + } + defer cli.Close() + clientState = cli.ConnectionState() + buf, err := io.ReadAll(cli) + if err != nil { + t.Errorf("failed to call cli.Read: %v", err) + } + if got := string(buf); got != sentinel { + t.Errorf("read %q from TLS connection, but expected %q", got, sentinel) + } + errChan <- nil + }() + server := Server(s, serverConfig) + err = server.Handshake() + if err == nil { + serverState = server.ConnectionState() + if _, err := io.WriteString(server, sentinel); err != nil { + t.Errorf("failed to call server.Write: %v", err) + } + if err := server.Close(); err != nil { + t.Errorf("failed to call server.Close: %v", err) + } + err = <-errChan + } else { + s.Close() + <-errChan + } + return +} + +func fromHex(s string) []byte { + b, _ := hex.DecodeString(s) + return b +} + +var testRSACertificate = fromHex("3082024b308201b4a003020102020900e8f09d3fe25beaa6300d06092a864886f70d01010b0500301f310b3009060355040a1302476f3110300e06035504031307476f20526f6f74301e170d3136303130313030303030305a170d3235303130313030303030305a301a310b3009060355040a1302476f310b300906035504031302476f30819f300d06092a864886f70d010101050003818d0030818902818100db467d932e12270648bc062821ab7ec4b6a25dfe1e5245887a3647a5080d92425bc281c0be97799840fb4f6d14fd2b138bc2a52e67d8d4099ed62238b74a0b74732bc234f1d193e596d9747bf3589f6c613cc0b041d4d92b2b2423775b1c3bbd755dce2054cfa163871d1e24c4f31d1a508baab61443ed97a77562f414c852d70203010001a38193308190300e0603551d0f0101ff0404030205a0301d0603551d250416301406082b0601050507030106082b06010505070302300c0603551d130101ff0402300030190603551d0e041204109f91161f43433e49a6de6db680d79f60301b0603551d230414301280104813494d137e1631bba301d5acab6e7b30190603551d1104123010820e6578616d706c652e676f6c616e67300d06092a864886f70d01010b0500038181009d30cc402b5b50a061cbbae55358e1ed8328a9581aa938a495a1ac315a1a84663d43d32dd90bf297dfd320643892243a00bccf9c7db74020015faad3166109a276fd13c3cce10c5ceeb18782f16c04ed73bbb343778d0c1cf10fa1d8408361c94c722b9daedb4606064df4c1b33ec0d1bd42d4dbfe3d1360845c21d33be9fae7") + +var testRSACertificateIssuer = fromHex("3082021930820182a003020102020900ca5e4e811a965964300d06092a864886f70d01010b0500301f310b3009060355040a1302476f3110300e06035504031307476f20526f6f74301e170d3136303130313030303030305a170d3235303130313030303030305a301f310b3009060355040a1302476f3110300e06035504031307476f20526f6f7430819f300d06092a864886f70d010101050003818d0030818902818100d667b378bb22f34143b6cd2008236abefaf2852adf3ab05e01329e2c14834f5105df3f3073f99dab5442d45ee5f8f57b0111c8cb682fbb719a86944eebfffef3406206d898b8c1b1887797c9c5006547bb8f00e694b7a063f10839f269f2c34fff7a1f4b21fbcd6bfdfb13ac792d1d11f277b5c5b48600992203059f2a8f8cc50203010001a35d305b300e0603551d0f0101ff040403020204301d0603551d250416301406082b0601050507030106082b06010505070302300f0603551d130101ff040530030101ff30190603551d0e041204104813494d137e1631bba301d5acab6e7b300d06092a864886f70d01010b050003818100c1154b4bab5266221f293766ae4138899bd4c5e36b13cee670ceeaa4cbdf4f6679017e2fe649765af545749fe4249418a56bd38a04b81e261f5ce86b8d5c65413156a50d12449554748c59a30c515bc36a59d38bddf51173e899820b282e40aa78c806526fd184fb6b4cf186ec728edffa585440d2b3225325f7ab580e87dd76") + +// testRSAPSSCertificate has signatureAlgorithm rsassaPss, but subjectPublicKeyInfo +// algorithm rsaEncryption, for use with the rsa_pss_rsae_* SignatureSchemes. +// See also TestRSAPSSKeyError. testRSAPSSCertificate is self-signed. +var testRSAPSSCertificate = fromHex("308202583082018da003020102021100f29926eb87ea8a0db9fcc247347c11b0304106092a864886f70d01010a3034a00f300d06096086480165030402010500a11c301a06092a864886f70d010108300d06096086480165030402010500a20302012030123110300e060355040a130741636d6520436f301e170d3137313132333136313631305a170d3138313132333136313631305a30123110300e060355040a130741636d6520436f30819f300d06092a864886f70d010101050003818d0030818902818100db467d932e12270648bc062821ab7ec4b6a25dfe1e5245887a3647a5080d92425bc281c0be97799840fb4f6d14fd2b138bc2a52e67d8d4099ed62238b74a0b74732bc234f1d193e596d9747bf3589f6c613cc0b041d4d92b2b2423775b1c3bbd755dce2054cfa163871d1e24c4f31d1a508baab61443ed97a77562f414c852d70203010001a3463044300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff04023000300f0603551d110408300687047f000001304106092a864886f70d01010a3034a00f300d06096086480165030402010500a11c301a06092a864886f70d010108300d06096086480165030402010500a20302012003818100cdac4ef2ce5f8d79881042707f7cbf1b5a8a00ef19154b40151771006cd41626e5496d56da0c1a139fd84695593cb67f87765e18aa03ea067522dd78d2a589b8c92364e12838ce346c6e067b51f1a7e6f4b37ffab13f1411896679d18e880e0ba09e302ac067efca460288e9538122692297ad8093d4f7dd701424d7700a46a1") + +var testECDSACertificate = fromHex("3082020030820162020900b8bf2d47a0d2ebf4300906072a8648ce3d04013045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464301e170d3132313132323135303633325a170d3232313132303135303633325a3045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c746430819b301006072a8648ce3d020106052b81040023038186000400c4a1edbe98f90b4873367ec316561122f23d53c33b4d213dcd6b75e6f6b0dc9adf26c1bcb287f072327cb3642f1c90bcea6823107efee325c0483a69e0286dd33700ef0462dd0da09c706283d881d36431aa9e9731bd96b068c09b23de76643f1a5c7fe9120e5858b65f70dd9bd8ead5d7f5d5ccb9b69f30665b669a20e227e5bffe3b300906072a8648ce3d040103818c0030818802420188a24febe245c5487d1bacf5ed989dae4770c05e1bb62fbdf1b64db76140d311a2ceee0b7e927eff769dc33b7ea53fcefa10e259ec472d7cacda4e970e15a06fd00242014dfcbe67139c2d050ebd3fa38c25c13313830d9406bbd4377af6ec7ac9862eddd711697f857c56defb31782be4c7780daecbbe9e4e3624317b6a0f399512078f2a") + +var testEd25519Certificate = fromHex("3082012e3081e1a00302010202100f431c425793941de987e4f1ad15005d300506032b657030123110300e060355040a130741636d6520436f301e170d3139303531363231333830315a170d3230303531353231333830315a30123110300e060355040a130741636d6520436f302a300506032b65700321003fe2152ee6e3ef3f4e854a7577a3649eede0bf842ccc92268ffa6f3483aaec8fa34d304b300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff0402300030160603551d11040f300d820b6578616d706c652e636f6d300506032b65700341006344ed9cc4be5324539fd2108d9fe82108909539e50dc155ff2c16b71dfcab7d4dd4e09313d0a942e0b66bfe5d6748d79f50bc6ccd4b03837cf20858cdaccf0c") + +var testSNICertificate = fromHex("0441883421114c81480804c430820237308201a0a003020102020900e8f09d3fe25beaa6300d06092a864886f70d01010b0500301f310b3009060355040a1302476f3110300e06035504031307476f20526f6f74301e170d3136303130313030303030305a170d3235303130313030303030305a3023310b3009060355040a1302476f311430120603550403130b736e69746573742e636f6d30819f300d06092a864886f70d010101050003818d0030818902818100db467d932e12270648bc062821ab7ec4b6a25dfe1e5245887a3647a5080d92425bc281c0be97799840fb4f6d14fd2b138bc2a52e67d8d4099ed62238b74a0b74732bc234f1d193e596d9747bf3589f6c613cc0b041d4d92b2b2423775b1c3bbd755dce2054cfa163871d1e24c4f31d1a508baab61443ed97a77562f414c852d70203010001a3773075300e0603551d0f0101ff0404030205a0301d0603551d250416301406082b0601050507030106082b06010505070302300c0603551d130101ff0402300030190603551d0e041204109f91161f43433e49a6de6db680d79f60301b0603551d230414301280104813494d137e1631bba301d5acab6e7b300d06092a864886f70d01010b0500038181007beeecff0230dbb2e7a334af65430b7116e09f327c3bbf918107fc9c66cb497493207ae9b4dbb045cb63d605ec1b5dd485bb69124d68fa298dc776699b47632fd6d73cab57042acb26f083c4087459bc5a3bb3ca4d878d7fe31016b7bc9a627438666566e3389bfaeebe6becc9a0093ceed18d0f9ac79d56f3a73f18188988ed") + +var testP256Certificate = fromHex("308201693082010ea00302010202105012dc24e1124ade4f3e153326ff27bf300a06082a8648ce3d04030230123110300e060355040a130741636d6520436f301e170d3137303533313232343934375a170d3138303533313232343934375a30123110300e060355040a130741636d6520436f3059301306072a8648ce3d020106082a8648ce3d03010703420004c02c61c9b16283bbcc14956d886d79b358aa614596975f78cece787146abf74c2d5dc578c0992b4f3c631373479ebf3892efe53d21c4f4f1cc9a11c3536b7f75a3463044300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff04023000300f0603551d1104083006820474657374300a06082a8648ce3d0403020349003046022100963712d6226c7b2bef41512d47e1434131aaca3ba585d666c924df71ac0448b3022100f4d05c725064741aef125f243cdbccaa2a5d485927831f221c43023bd5ae471a") + +var testRSAPrivateKey, _ = x509.ParsePKCS1PrivateKey(fromHex("3082025b02010002818100db467d932e12270648bc062821ab7ec4b6a25dfe1e5245887a3647a5080d92425bc281c0be97799840fb4f6d14fd2b138bc2a52e67d8d4099ed62238b74a0b74732bc234f1d193e596d9747bf3589f6c613cc0b041d4d92b2b2423775b1c3bbd755dce2054cfa163871d1e24c4f31d1a508baab61443ed97a77562f414c852d702030100010281800b07fbcf48b50f1388db34b016298b8217f2092a7c9a04f77db6775a3d1279b62ee9951f7e371e9de33f015aea80660760b3951dc589a9f925ed7de13e8f520e1ccbc7498ce78e7fab6d59582c2386cc07ed688212a576ff37833bd5943483b5554d15a0b9b4010ed9bf09f207e7e9805f649240ed6c1256ed75ab7cd56d9671024100fded810da442775f5923debae4ac758390a032a16598d62f059bb2e781a9c2f41bfa015c209f966513fe3bf5a58717cbdb385100de914f88d649b7d15309fa49024100dd10978c623463a1802c52f012cfa72ff5d901f25a2292446552c2568b1840e49a312e127217c2186615aae4fb6602a4f6ebf3f3d160f3b3ad04c592f65ae41f02400c69062ca781841a09de41ed7a6d9f54adc5d693a2c6847949d9e1358555c9ac6a8d9e71653ac77beb2d3abaf7bb1183aa14278956575dbebf525d0482fd72d90240560fe1900ba36dae3022115fd952f2399fb28e2975a1c3e3d0b679660bdcb356cc189d611cfdd6d87cd5aea45aa30a2082e8b51e94c2f3dd5d5c6036a8a615ed0240143993d80ece56f877cb80048335701eb0e608cc0c1ca8c2227b52edf8f1ac99c562f2541b5ce81f0515af1c5b4770dba53383964b4b725ff46fdec3d08907df")) + +var testECDSAPrivateKey, _ = x509.ParseECPrivateKey(fromHex("3081dc0201010442019883e909ad0ac9ea3d33f9eae661f1785206970f8ca9a91672f1eedca7a8ef12bd6561bb246dda5df4b4d5e7e3a92649bc5d83a0bf92972e00e62067d0c7bd99d7a00706052b81040023a18189038186000400c4a1edbe98f90b4873367ec316561122f23d53c33b4d213dcd6b75e6f6b0dc9adf26c1bcb287f072327cb3642f1c90bcea6823107efee325c0483a69e0286dd33700ef0462dd0da09c706283d881d36431aa9e9731bd96b068c09b23de76643f1a5c7fe9120e5858b65f70dd9bd8ead5d7f5d5ccb9b69f30665b669a20e227e5bffe3b")) + +var testP256PrivateKey, _ = x509.ParseECPrivateKey(fromHex("30770201010420012f3b52bc54c36ba3577ad45034e2e8efe1e6999851284cb848725cfe029991a00a06082a8648ce3d030107a14403420004c02c61c9b16283bbcc14956d886d79b358aa614596975f78cece787146abf74c2d5dc578c0992b4f3c631373479ebf3892efe53d21c4f4f1cc9a11c3536b7f75")) + +var testEd25519PrivateKey = ed25519.PrivateKey(fromHex("3a884965e76b3f55e5faf9615458a92354894234de3ec9f684d46d55cebf3dc63fe2152ee6e3ef3f4e854a7577a3649eede0bf842ccc92268ffa6f3483aaec8f")) + +const clientCertificatePEM = ` +-----BEGIN CERTIFICATE----- +MIIB7zCCAVigAwIBAgIQXBnBiWWDVW/cC8m5k5/pvDANBgkqhkiG9w0BAQsFADAS +MRAwDgYDVQQKEwdBY21lIENvMB4XDTE2MDgxNzIxNTIzMVoXDTE3MDgxNzIxNTIz +MVowEjEQMA4GA1UEChMHQWNtZSBDbzCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkC +gYEAum+qhr3Pv5/y71yUYHhv6BPy0ZZvzdkybiI3zkH5yl0prOEn2mGi7oHLEMff +NFiVhuk9GeZcJ3NgyI14AvQdpJgJoxlwaTwlYmYqqyIjxXuFOE8uCXMyp70+m63K +hAfmDzr/d8WdQYUAirab7rCkPy1MTOZCPrtRyN1IVPQMjkcCAwEAAaNGMEQwDgYD +VR0PAQH/BAQDAgWgMBMGA1UdJQQMMAoGCCsGAQUFBwMBMAwGA1UdEwEB/wQCMAAw +DwYDVR0RBAgwBocEfwAAATANBgkqhkiG9w0BAQsFAAOBgQBGq0Si+yhU+Fpn+GKU +8ZqyGJ7ysd4dfm92lam6512oFmyc9wnTN+RLKzZ8Aa1B0jLYw9KT+RBrjpW5LBeK +o0RIvFkTgxYEiKSBXCUNmAysEbEoVr4dzWFihAm/1oDGRY2CLLTYg5vbySK3KhIR +e/oCO8HJ/+rJnahJ05XX1Q7lNQ== +-----END CERTIFICATE-----` + +var clientKeyPEM = testingKey(` +-----BEGIN RSA TESTING KEY----- +MIICXQIBAAKBgQC6b6qGvc+/n/LvXJRgeG/oE/LRlm/N2TJuIjfOQfnKXSms4Sfa +YaLugcsQx980WJWG6T0Z5lwnc2DIjXgC9B2kmAmjGXBpPCViZiqrIiPFe4U4Ty4J +czKnvT6brcqEB+YPOv93xZ1BhQCKtpvusKQ/LUxM5kI+u1HI3UhU9AyORwIDAQAB +AoGAEJZ03q4uuMb7b26WSQsOMeDsftdatT747LGgs3pNRkMJvTb/O7/qJjxoG+Mc +qeSj0TAZXp+PXXc3ikCECAc+R8rVMfWdmp903XgO/qYtmZGCorxAHEmR80SrfMXv +PJnznLQWc8U9nphQErR+tTESg7xWEzmFcPKwnZd1xg8ERYkCQQDTGtrFczlB2b/Z +9TjNMqUlMnTLIk/a/rPE2fLLmAYhK5sHnJdvDURaH2mF4nso0EGtENnTsh6LATnY +dkrxXGm9AkEA4hXHG2q3MnhgK1Z5hjv+Fnqd+8bcbII9WW4flFs15EKoMgS1w/PJ +zbsySaSy5IVS8XeShmT9+3lrleed4sy+UwJBAJOOAbxhfXP5r4+5R6ql66jES75w +jUCVJzJA5ORJrn8g64u2eGK28z/LFQbv9wXgCwfc72R468BdawFSLa/m2EECQGbZ +rWiFla26IVXV0xcD98VWJsTBZMlgPnSOqoMdM1kSEd4fUmlAYI/dFzV1XYSkOmVr +FhdZnklmpVDeu27P4c0CQQCuCOup0FlJSBpWY1TTfun/KMBkBatMz0VMA3d7FKIU +csPezl677Yjo8u1r/KzeI6zLg87Z8E6r6ZWNc9wBSZK6 +-----END RSA TESTING KEY-----`) + +const clientECDSACertificatePEM = ` +-----BEGIN CERTIFICATE----- +MIIB/DCCAV4CCQCaMIRsJjXZFzAJBgcqhkjOPQQBMEUxCzAJBgNVBAYTAkFVMRMw +EQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0 +eSBMdGQwHhcNMTIxMTE0MTMyNTUzWhcNMjIxMTEyMTMyNTUzWjBBMQswCQYDVQQG +EwJBVTEMMAoGA1UECBMDTlNXMRAwDgYDVQQHEwdQeXJtb250MRIwEAYDVQQDEwlK +b2VsIFNpbmcwgZswEAYHKoZIzj0CAQYFK4EEACMDgYYABACVjJF1FMBexFe01MNv +ja5oHt1vzobhfm6ySD6B5U7ixohLZNz1MLvT/2XMW/TdtWo+PtAd3kfDdq0Z9kUs +jLzYHQFMH3CQRnZIi4+DzEpcj0B22uCJ7B0rxE4wdihBsmKo+1vx+U56jb0JuK7q +ixgnTy5w/hOWusPTQBbNZU6sER7m8TAJBgcqhkjOPQQBA4GMADCBiAJCAOAUxGBg +C3JosDJdYUoCdFzCgbkWqD8pyDbHgf9stlvZcPE4O1BIKJTLCRpS8V3ujfK58PDa +2RU6+b0DeoeiIzXsAkIBo9SKeDUcSpoj0gq+KxAxnZxfvuiRs9oa9V2jI/Umi0Vw +jWVim34BmT0Y9hCaOGGbLlfk+syxis7iI6CH8OFnUes= +-----END CERTIFICATE-----` + +var clientECDSAKeyPEM = testingKey(` +-----BEGIN EC PARAMETERS----- +BgUrgQQAIw== +-----END EC PARAMETERS----- +-----BEGIN EC TESTING KEY----- +MIHcAgEBBEIBkJN9X4IqZIguiEVKMqeBUP5xtRsEv4HJEtOpOGLELwO53SD78Ew8 +k+wLWoqizS3NpQyMtrU8JFdWfj+C57UNkOugBwYFK4EEACOhgYkDgYYABACVjJF1 +FMBexFe01MNvja5oHt1vzobhfm6ySD6B5U7ixohLZNz1MLvT/2XMW/TdtWo+PtAd +3kfDdq0Z9kUsjLzYHQFMH3CQRnZIi4+DzEpcj0B22uCJ7B0rxE4wdihBsmKo+1vx ++U56jb0JuK7qixgnTy5w/hOWusPTQBbNZU6sER7m8Q== +-----END EC TESTING KEY-----`) + +const clientEd25519CertificatePEM = ` +-----BEGIN CERTIFICATE----- +MIIBLjCB4aADAgECAhAX0YGTviqMISAQJRXoNCNPMAUGAytlcDASMRAwDgYDVQQK +EwdBY21lIENvMB4XDTE5MDUxNjIxNTQyNloXDTIwMDUxNTIxNTQyNlowEjEQMA4G +A1UEChMHQWNtZSBDbzAqMAUGAytlcAMhAAvgtWC14nkwPb7jHuBQsQTIbcd4bGkv +xRStmmNveRKRo00wSzAOBgNVHQ8BAf8EBAMCBaAwEwYDVR0lBAwwCgYIKwYBBQUH +AwIwDAYDVR0TAQH/BAIwADAWBgNVHREEDzANggtleGFtcGxlLmNvbTAFBgMrZXAD +QQD8GRcqlKUx+inILn9boF2KTjRAOdazENwZ/qAicbP1j6FYDc308YUkv+Y9FN/f +7Q7hF9gRomDQijcjKsJGqjoI +-----END CERTIFICATE-----` + +var clientEd25519KeyPEM = testingKey(` +-----BEGIN TESTING KEY----- +MC4CAQAwBQYDK2VwBCIEINifzf07d9qx3d44e0FSbV4mC/xQxT644RRbpgNpin7I +-----END TESTING KEY-----`) diff --git a/pkg/tls/handshake_unix_test.go b/pkg/tls/handshake_unix_test.go new file mode 100644 index 000000000..86a48f299 --- /dev/null +++ b/pkg/tls/handshake_unix_test.go @@ -0,0 +1,18 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build unix + +package tls + +import ( + "errors" + "syscall" +) + +func init() { + isConnRefused = func(err error) bool { + return errors.Is(err, syscall.ECONNREFUSED) + } +} diff --git a/pkg/tls/key_agreement.go b/pkg/tls/key_agreement.go index 7e6534bd4..2c8c5b8d7 100644 --- a/pkg/tls/key_agreement.go +++ b/pkg/tls/key_agreement.go @@ -6,6 +6,7 @@ package tls import ( "crypto" + "crypto/ecdh" "crypto/md5" "crypto/rsa" "crypto/sha1" @@ -15,6 +16,25 @@ import ( "io" ) +// a keyAgreement implements the client and server side of a TLS key agreement +// protocol by generating and processing key exchange messages. +type keyAgreement interface { + // On the server side, the first two methods are called in order. + + // In the case that the key agreement protocol doesn't use a + // ServerKeyExchange message, generateServerKeyExchange can return nil, + // nil. + generateServerKeyExchange(*Config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error) + processClientKeyExchange(*Config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error) + + // On the client side, the next two methods are called in order. + + // This method may not be called if the server doesn't send a + // ServerKeyExchange message. + processServerKeyExchange(*Config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error + generateClientKeyExchange(*Config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) +} + var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message") var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message") @@ -67,7 +87,11 @@ func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello return nil, nil, err } - encrypted, err := rsa.EncryptPKCS1v15(config.rand(), cert.PublicKey.(*rsa.PublicKey), preMasterSecret) + rsaKey, ok := cert.PublicKey.(*rsa.PublicKey) + if !ok { + return nil, nil, errors.New("tls: server certificate contains incorrect key type for selected ciphersuite") + } + encrypted, err := rsa.EncryptPKCS1v15(config.rand(), rsaKey, preMasterSecret) if err != nil { return nil, nil, err } @@ -134,7 +158,7 @@ func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint1 type ecdheKeyAgreement struct { version uint16 isRSA bool - params ecdheParameters + key *ecdh.PrivateKey // ckx and preMasterSecret are generated in processServerKeyExchange // and returned in generateClientKeyExchange. @@ -154,18 +178,18 @@ func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Cer if curveID == 0 { return nil, errors.New("tls: no supported elliptic curves offered") } - if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + if _, ok := curveForCurveID(curveID); !ok { return nil, errors.New("tls: CurvePreferences includes unsupported curve") } - params, err := generateECDHEParameters(config.rand(), curveID) + key, err := generateECDHEKey(config.rand(), curveID) if err != nil { return nil, err } - ka.params = params + ka.key = key // See RFC 4492, Section 5.4. - ecdhePublic := params.PublicKey() + ecdhePublic := key.PublicKey().Bytes() serverECDHEParams := make([]byte, 1+2+1+len(ecdhePublic)) serverECDHEParams[0] = 3 // named curve serverECDHEParams[1] = byte(curveID >> 8) @@ -236,8 +260,12 @@ func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Cert return nil, errClientKeyExchange } - preMasterSecret := ka.params.SharedKey(ckx.ciphertext[1:]) - if preMasterSecret == nil { + peerKey, err := ka.key.Curve().NewPublicKey(ckx.ciphertext[1:]) + if err != nil { + return nil, errClientKeyExchange + } + preMasterSecret, err := ka.key.ECDH(peerKey) + if err != nil { return nil, errClientKeyExchange } @@ -265,22 +293,26 @@ func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHell return errServerKeyExchange } - if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + if _, ok := curveForCurveID(curveID); !ok { return errors.New("tls: server selected unsupported curve") } - params, err := generateECDHEParameters(config.rand(), curveID) + key, err := generateECDHEKey(config.rand(), curveID) if err != nil { return err } - ka.params = params + ka.key = key - ka.preMasterSecret = params.SharedKey(publicKey) - if ka.preMasterSecret == nil { + peerKey, err := key.Curve().NewPublicKey(publicKey) + if err != nil { + return errServerKeyExchange + } + ka.preMasterSecret, err = key.ECDH(peerKey) + if err != nil { return errServerKeyExchange } - ourPublicKey := params.PublicKey() + ourPublicKey := key.PublicKey().Bytes() ka.ckx = new(clientKeyExchangeMsg) ka.ckx.ciphertext = make([]byte, 1+len(ourPublicKey)) ka.ckx.ciphertext[0] = byte(len(ourPublicKey)) diff --git a/pkg/tls/key_schedule.go b/pkg/tls/key_schedule.go index 314016979..8150d804a 100644 --- a/pkg/tls/key_schedule.go +++ b/pkg/tls/key_schedule.go @@ -5,15 +5,13 @@ package tls import ( - "crypto/elliptic" + "crypto/ecdh" "crypto/hmac" "errors" "hash" "io" - "math/big" "golang.org/x/crypto/cryptobyte" - "golang.org/x/crypto/curve25519" "golang.org/x/crypto/hkdf" ) @@ -101,99 +99,43 @@ func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript } } -// ecdheParameters implements Diffie-Hellman with either NIST curves or X25519, +// generateECDHEKey returns a PrivateKey that implements Diffie-Hellman // according to RFC 8446, Section 4.2.8.2. -type ecdheParameters interface { - CurveID() CurveID - PublicKey() []byte - SharedKey(peerPublicKey []byte) []byte -} - -func generateECDHEParameters(rand io.Reader, curveID CurveID) (ecdheParameters, error) { - if curveID == X25519 { - privateKey := make([]byte, curve25519.ScalarSize) - if _, err := io.ReadFull(rand, privateKey); err != nil { - return nil, err - } - publicKey, err := curve25519.X25519(privateKey, curve25519.Basepoint) - if err != nil { - return nil, err - } - return &x25519Parameters{privateKey: privateKey, publicKey: publicKey}, nil - } - +func generateECDHEKey(rand io.Reader, curveID CurveID) (*ecdh.PrivateKey, error) { curve, ok := curveForCurveID(curveID) if !ok { return nil, errors.New("tls: internal error: unsupported curve") } - p := &nistParameters{curveID: curveID} - var err error - p.privateKey, p.x, p.y, err = elliptic.GenerateKey(curve, rand) - if err != nil { - return nil, err - } - return p, nil + return curve.GenerateKey(rand) } -func curveForCurveID(id CurveID) (elliptic.Curve, bool) { +func curveForCurveID(id CurveID) (ecdh.Curve, bool) { switch id { + case X25519: + return ecdh.X25519(), true case CurveP256: - return elliptic.P256(), true + return ecdh.P256(), true case CurveP384: - return elliptic.P384(), true + return ecdh.P384(), true case CurveP521: - return elliptic.P521(), true + return ecdh.P521(), true default: return nil, false } } -type nistParameters struct { - privateKey []byte - x, y *big.Int // public key - curveID CurveID -} - -func (p *nistParameters) CurveID() CurveID { - return p.curveID -} - -func (p *nistParameters) PublicKey() []byte { - curve, _ := curveForCurveID(p.curveID) - return elliptic.Marshal(curve, p.x, p.y) -} - -func (p *nistParameters) SharedKey(peerPublicKey []byte) []byte { - curve, _ := curveForCurveID(p.curveID) - // Unmarshal also checks whether the given point is on the curve. - x, y := elliptic.Unmarshal(curve, peerPublicKey) - if x == nil { - return nil - } - - xShared, _ := curve.ScalarMult(x, y, p.privateKey) - sharedKey := make([]byte, (curve.Params().BitSize+7)/8) - return xShared.FillBytes(sharedKey) -} - -type x25519Parameters struct { - privateKey []byte - publicKey []byte -} - -func (p *x25519Parameters) CurveID() CurveID { - return X25519 -} - -func (p *x25519Parameters) PublicKey() []byte { - return p.publicKey[:] -} - -func (p *x25519Parameters) SharedKey(peerPublicKey []byte) []byte { - sharedKey, err := curve25519.X25519(p.privateKey, peerPublicKey) - if err != nil { - return nil +func curveIDForCurve(curve ecdh.Curve) (CurveID, bool) { + switch curve { + case ecdh.X25519(): + return X25519, true + case ecdh.P256(): + return CurveP256, true + case ecdh.P384(): + return CurveP384, true + case ecdh.P521(): + return CurveP521, true + default: + return 0, false } - return sharedKey } diff --git a/pkg/tls/key_schedule_test.go b/pkg/tls/key_schedule_test.go new file mode 100644 index 000000000..79ff6a62b --- /dev/null +++ b/pkg/tls/key_schedule_test.go @@ -0,0 +1,175 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "encoding/hex" + "hash" + "strings" + "testing" + "unicode" +) + +// This file contains tests derived from draft-ietf-tls-tls13-vectors-07. + +func parseVector(v string) []byte { + v = strings.Map(func(c rune) rune { + if unicode.IsSpace(c) { + return -1 + } + return c + }, v) + parts := strings.Split(v, ":") + v = parts[len(parts)-1] + res, err := hex.DecodeString(v) + if err != nil { + panic(err) + } + return res +} + +func TestDeriveSecret(t *testing.T) { + chTranscript := cipherSuitesTLS13[0].hash.New() + chTranscript.Write(parseVector(` + payload (512 octets): 01 00 01 fc 03 03 1b c3 ce b6 bb e3 9c ff + 93 83 55 b5 a5 0a db 6d b2 1b 7a 6a f6 49 d7 b4 bc 41 9d 78 76 + 48 7d 95 00 00 06 13 01 13 03 13 02 01 00 01 cd 00 00 00 0b 00 + 09 00 00 06 73 65 72 76 65 72 ff 01 00 01 00 00 0a 00 14 00 12 + 00 1d 00 17 00 18 00 19 01 00 01 01 01 02 01 03 01 04 00 33 00 + 26 00 24 00 1d 00 20 e4 ff b6 8a c0 5f 8d 96 c9 9d a2 66 98 34 + 6c 6b e1 64 82 ba dd da fe 05 1a 66 b4 f1 8d 66 8f 0b 00 2a 00 + 00 00 2b 00 03 02 03 04 00 0d 00 20 00 1e 04 03 05 03 06 03 02 + 03 08 04 08 05 08 06 04 01 05 01 06 01 02 01 04 02 05 02 06 02 + 02 02 00 2d 00 02 01 01 00 1c 00 02 40 01 00 15 00 57 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 29 00 dd 00 b8 00 b2 2c 03 5d 82 93 59 ee 5f f7 af 4e c9 00 + 00 00 00 26 2a 64 94 dc 48 6d 2c 8a 34 cb 33 fa 90 bf 1b 00 70 + ad 3c 49 88 83 c9 36 7c 09 a2 be 78 5a bc 55 cd 22 60 97 a3 a9 + 82 11 72 83 f8 2a 03 a1 43 ef d3 ff 5d d3 6d 64 e8 61 be 7f d6 + 1d 28 27 db 27 9c ce 14 50 77 d4 54 a3 66 4d 4e 6d a4 d2 9e e0 + 37 25 a6 a4 da fc d0 fc 67 d2 ae a7 05 29 51 3e 3d a2 67 7f a5 + 90 6c 5b 3f 7d 8f 92 f2 28 bd a4 0d da 72 14 70 f9 fb f2 97 b5 + ae a6 17 64 6f ac 5c 03 27 2e 97 07 27 c6 21 a7 91 41 ef 5f 7d + e6 50 5e 5b fb c3 88 e9 33 43 69 40 93 93 4a e4 d3 57 fa d6 aa + cb 00 21 20 3a dd 4f b2 d8 fd f8 22 a0 ca 3c f7 67 8e f5 e8 8d + ae 99 01 41 c5 92 4d 57 bb 6f a3 1b 9e 5f 9d`)) + + type args struct { + secret []byte + label string + transcript hash.Hash + } + tests := []struct { + name string + args args + want []byte + }{ + { + `derive secret for handshake "tls13 derived"`, + args{ + parseVector(`PRK (32 octets): 33 ad 0a 1c 60 7e c0 3b 09 e6 cd 98 93 68 0c e2 + 10 ad f3 00 aa 1f 26 60 e1 b2 2e 10 f1 70 f9 2a`), + "derived", + nil, + }, + parseVector(`expanded (32 octets): 6f 26 15 a1 08 c7 02 c5 67 8f 54 fc 9d ba + b6 97 16 c0 76 18 9c 48 25 0c eb ea c3 57 6c 36 11 ba`), + }, + { + `derive secret "tls13 c e traffic"`, + args{ + parseVector(`PRK (32 octets): 9b 21 88 e9 b2 fc 6d 64 d7 1d c3 29 90 0e 20 bb + 41 91 50 00 f6 78 aa 83 9c bb 79 7c b7 d8 33 2c`), + "c e traffic", + chTranscript, + }, + parseVector(`expanded (32 octets): 3f bb e6 a6 0d eb 66 c3 0a 32 79 5a ba 0e + ff 7e aa 10 10 55 86 e7 be 5c 09 67 8d 63 b6 ca ab 62`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := cipherSuitesTLS13[0] + if got := c.deriveSecret(tt.args.secret, tt.args.label, tt.args.transcript); !bytes.Equal(got, tt.want) { + t.Errorf("cipherSuiteTLS13.deriveSecret() = % x, want % x", got, tt.want) + } + }) + } +} + +func TestTrafficKey(t *testing.T) { + trafficSecret := parseVector( + `PRK (32 octets): b6 7b 7d 69 0c c1 6c 4e 75 e5 42 13 cb 2d 37 b4 + e9 c9 12 bc de d9 10 5d 42 be fd 59 d3 91 ad 38`) + wantKey := parseVector( + `key expanded (16 octets): 3f ce 51 60 09 c2 17 27 d0 f2 e4 e8 6e + e4 03 bc`) + wantIV := parseVector( + `iv expanded (12 octets): 5d 31 3e b2 67 12 76 ee 13 00 0b 30`) + + c := cipherSuitesTLS13[0] + gotKey, gotIV := c.trafficKey(trafficSecret) + if !bytes.Equal(gotKey, wantKey) { + t.Errorf("cipherSuiteTLS13.trafficKey() gotKey = % x, want % x", gotKey, wantKey) + } + if !bytes.Equal(gotIV, wantIV) { + t.Errorf("cipherSuiteTLS13.trafficKey() gotIV = % x, want % x", gotIV, wantIV) + } +} + +func TestExtract(t *testing.T) { + type args struct { + newSecret []byte + currentSecret []byte + } + tests := []struct { + name string + args args + want []byte + }{ + { + `extract secret "early"`, + args{ + nil, + nil, + }, + parseVector(`secret (32 octets): 33 ad 0a 1c 60 7e c0 3b 09 e6 cd 98 93 68 0c + e2 10 ad f3 00 aa 1f 26 60 e1 b2 2e 10 f1 70 f9 2a`), + }, + { + `extract secret "master"`, + args{ + nil, + parseVector(`salt (32 octets): 43 de 77 e0 c7 77 13 85 9a 94 4d b9 db 25 90 b5 + 31 90 a6 5b 3e e2 e4 f1 2d d7 a0 bb 7c e2 54 b4`), + }, + parseVector(`secret (32 octets): 18 df 06 84 3d 13 a0 8b f2 a4 49 84 4c 5f 8a + 47 80 01 bc 4d 4c 62 79 84 d5 a4 1d a8 d0 40 29 19`), + }, + { + `extract secret "handshake"`, + args{ + parseVector(`IKM (32 octets): 8b d4 05 4f b5 5b 9d 63 fd fb ac f9 f0 4b 9f 0d + 35 e6 d6 3f 53 75 63 ef d4 62 72 90 0f 89 49 2d`), + parseVector(`salt (32 octets): 6f 26 15 a1 08 c7 02 c5 67 8f 54 fc 9d ba b6 97 + 16 c0 76 18 9c 48 25 0c eb ea c3 57 6c 36 11 ba`), + }, + parseVector(`secret (32 octets): 1d c8 26 e9 36 06 aa 6f dc 0a ad c1 2f 74 1b + 01 04 6a a6 b9 9f 69 1e d2 21 a9 f0 ca 04 3f be ac`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := cipherSuitesTLS13[0] + if got := c.extract(tt.args.newSecret, tt.args.currentSecret); !bytes.Equal(got, tt.want) { + t.Errorf("cipherSuiteTLS13.extract() = % x, want % x", got, tt.want) + } + }) + } +} diff --git a/pkg/tls/notboring.go b/pkg/tls/notboring.go new file mode 100644 index 000000000..7d85b39c5 --- /dev/null +++ b/pkg/tls/notboring.go @@ -0,0 +1,20 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !boringcrypto + +package tls + +func needFIPS() bool { return false } + +func supportedSignatureAlgorithms() []SignatureScheme { + return defaultSupportedSignatureAlgorithms +} + +func fipsMinVersion(c *Config) uint16 { panic("fipsMinVersion") } +func fipsMaxVersion(c *Config) uint16 { panic("fipsMaxVersion") } +func fipsCurvePreferences(c *Config) []CurveID { panic("fipsCurvePreferences") } +func fipsCipherSuites(c *Config) []uint16 { panic("fipsCipherSuites") } + +var fipsSupportedSignatureAlgorithms []SignatureScheme diff --git a/pkg/tls/prf.go b/pkg/tls/prf.go index 13bfa009c..b60166dee 100644 --- a/pkg/tls/prf.go +++ b/pkg/tls/prf.go @@ -215,7 +215,7 @@ func (h finishedHash) serverSum(masterSecret []byte) []byte { // hashForClientCertificate returns the handshake messages so far, pre-hashed if // necessary, suitable for signing by a TLS client certificate. -func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash, masterSecret []byte) []byte { +func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash) []byte { if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil { panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer") } diff --git a/pkg/tls/prf_test.go b/pkg/tls/prf_test.go new file mode 100644 index 000000000..8233985a6 --- /dev/null +++ b/pkg/tls/prf_test.go @@ -0,0 +1,140 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "encoding/hex" + "testing" +) + +type testSplitPreMasterSecretTest struct { + in, out1, out2 string +} + +var testSplitPreMasterSecretTests = []testSplitPreMasterSecretTest{ + {"", "", ""}, + {"00", "00", "00"}, + {"0011", "00", "11"}, + {"001122", "0011", "1122"}, + {"00112233", "0011", "2233"}, +} + +func TestSplitPreMasterSecret(t *testing.T) { + for i, test := range testSplitPreMasterSecretTests { + in, _ := hex.DecodeString(test.in) + out1, out2 := splitPreMasterSecret(in) + s1 := hex.EncodeToString(out1) + s2 := hex.EncodeToString(out2) + if s1 != test.out1 || s2 != test.out2 { + t.Errorf("#%d: got: (%s, %s) want: (%s, %s)", i, s1, s2, test.out1, test.out2) + } + } +} + +type testKeysFromTest struct { + version uint16 + suite *cipherSuite + preMasterSecret string + clientRandom, serverRandom string + masterSecret string + clientMAC, serverMAC string + clientKey, serverKey string + macLen, keyLen int + contextKeyingMaterial, noContextKeyingMaterial string +} + +func TestKeysFromPreMasterSecret(t *testing.T) { + for i, test := range testKeysFromTests { + in, _ := hex.DecodeString(test.preMasterSecret) + clientRandom, _ := hex.DecodeString(test.clientRandom) + serverRandom, _ := hex.DecodeString(test.serverRandom) + + masterSecret := masterFromPreMasterSecret(test.version, test.suite, in, clientRandom, serverRandom) + if s := hex.EncodeToString(masterSecret); s != test.masterSecret { + t.Errorf("#%d: bad master secret %s, want %s", i, s, test.masterSecret) + continue + } + + clientMAC, serverMAC, clientKey, serverKey, _, _ := keysFromMasterSecret(test.version, test.suite, masterSecret, clientRandom, serverRandom, test.macLen, test.keyLen, 0) + clientMACString := hex.EncodeToString(clientMAC) + serverMACString := hex.EncodeToString(serverMAC) + clientKeyString := hex.EncodeToString(clientKey) + serverKeyString := hex.EncodeToString(serverKey) + if clientMACString != test.clientMAC || + serverMACString != test.serverMAC || + clientKeyString != test.clientKey || + serverKeyString != test.serverKey { + t.Errorf("#%d: got: (%s, %s, %s, %s) want: (%s, %s, %s, %s)", i, clientMACString, serverMACString, clientKeyString, serverKeyString, test.clientMAC, test.serverMAC, test.clientKey, test.serverKey) + } + + ekm := ekmFromMasterSecret(test.version, test.suite, masterSecret, clientRandom, serverRandom) + contextKeyingMaterial, err := ekm("label", []byte("context"), 32) + if err != nil { + t.Fatalf("ekmFromMasterSecret failed: %v", err) + } + + noContextKeyingMaterial, err := ekm("label", nil, 32) + if err != nil { + t.Fatalf("ekmFromMasterSecret failed: %v", err) + } + + if hex.EncodeToString(contextKeyingMaterial) != test.contextKeyingMaterial || + hex.EncodeToString(noContextKeyingMaterial) != test.noContextKeyingMaterial { + t.Errorf("#%d: got keying material: (%s, %s) want: (%s, %s)", i, contextKeyingMaterial, noContextKeyingMaterial, test.contextKeyingMaterial, test.noContextKeyingMaterial) + } + } +} + +// These test vectors were generated from GnuTLS using `gnutls-cli --insecure -d 9 ` +var testKeysFromTests = []testKeysFromTest{ + { + VersionTLS10, + cipherSuiteByID(TLS_RSA_WITH_RC4_128_SHA), + "0302cac83ad4b1db3b9ab49ad05957de2a504a634a386fc600889321e1a971f57479466830ac3e6f468e87f5385fa0c5", + "4ae66303755184a3917fcb44880605fcc53baa01912b22ed94473fc69cebd558", + "4ae663020ec16e6bb5130be918cfcafd4d765979a3136a5d50c593446e4e44db", + "3d851bab6e5556e959a16bc36d66cfae32f672bfa9ecdef6096cbb1b23472df1da63dbbd9827606413221d149ed08ceb", + "805aaa19b3d2c0a0759a4b6c9959890e08480119", + "2d22f9fe519c075c16448305ceee209fc24ad109", + "d50b5771244f850cd8117a9ccafe2cf1", + "e076e33206b30507a85c32855acd0919", + 20, + 16, + "4d1bb6fc278c37d27aa6e2a13c2e079095d143272c2aa939da33d88c1c0cec22", + "93fba89599b6321ae538e27c6548ceb8b46821864318f5190d64a375e5d69d41", + }, + { + VersionTLS10, + cipherSuiteByID(TLS_RSA_WITH_RC4_128_SHA), + "03023f7527316bc12cbcd69e4b9e8275d62c028f27e65c745cfcddc7ce01bd3570a111378b63848127f1c36e5f9e4890", + "4ae66364b5ea56b20ce4e25555aed2d7e67f42788dd03f3fee4adae0459ab106", + "4ae66363ab815cbf6a248b87d6b556184e945e9b97fbdf247858b0bdafacfa1c", + "7d64be7c80c59b740200b4b9c26d0baaa1c5ae56705acbcf2307fe62beb4728c19392c83f20483801cce022c77645460", + "97742ed60a0554ca13f04f97ee193177b971e3b0", + "37068751700400e03a8477a5c7eec0813ab9e0dc", + "207cddbc600d2a200abac6502053ee5c", + "df3f94f6e1eacc753b815fe16055cd43", + 20, + 16, + "2c9f8961a72b97cbe76553b5f954caf8294fc6360ef995ac1256fe9516d0ce7f", + "274f19c10291d188857ad8878e2119f5aa437d4da556601cf1337aff23154016", + }, + { + VersionTLS10, + cipherSuiteByID(TLS_RSA_WITH_RC4_128_SHA), + "832d515f1d61eebb2be56ba0ef79879efb9b527504abb386fb4310ed5d0e3b1f220d3bb6b455033a2773e6d8bdf951d278a187482b400d45deb88a5d5a6bb7d6a7a1decc04eb9ef0642876cd4a82d374d3b6ff35f0351dc5d411104de431375355addc39bfb1f6329fb163b0bc298d658338930d07d313cd980a7e3d9196cac1", + "4ae663b2ee389c0de147c509d8f18f5052afc4aaf9699efe8cb05ece883d3a5e", + "4ae664d503fd4cff50cfc1fb8fc606580f87b0fcdac9554ba0e01d785bdf278e", + "1aff2e7a2c4279d0126f57a65a77a8d9d0087cf2733366699bec27eb53d5740705a8574bb1acc2abbe90e44f0dd28d6c", + "3c7647c93c1379a31a609542aa44e7f117a70085", + "0d73102994be74a575a3ead8532590ca32a526d4", + "ac7581b0b6c10d85bbd905ffbf36c65e", + "ff07edde49682b45466bd2e39464b306", + 20, + 16, + "678b0d43f607de35241dc7e9d1a7388a52c35033a1a0336d4d740060a6638fe2", + "f3b4ac743f015ef21d79978297a53da3e579ee047133f38c234d829c0f907dab", + }, +} diff --git a/pkg/tls/tls.go b/pkg/tls/tls.go index 5ede0b614..0dd65b0c2 100644 --- a/pkg/tls/tls.go +++ b/pkg/tls/tls.go @@ -21,49 +21,76 @@ import ( "encoding/pem" "errors" "fmt" - "io/ioutil" "net" + "os" "strings" "github.com/panjf2000/gnet/v2/pkg/buffer/elastic" ) -type conn interface { - Write([]byte) (int, error) - RemoteAddr() net.Addr -} - // Server returns a new TLS server side connection // using conn as the underlying transport. // The configuration config must be non-nil and must include // at least one certificate or else set GetCertificate. -func Server(c conn, in *elastic.RingBuffer, out *elastic.Buffer, config *Config) (*Conn, error) { - tlsconn := &Conn{ - conn: c, +func ServerGnet(conn net.Conn, in *elastic.RingBuffer, out *elastic.Buffer, config *Config) *Conn { + c := &Conn{ + conn: conn, config: config, input: in, sendBuf: out, - outBuf: []byte{0, 3, 3, 0, 0}, } - - return tlsconn, nil + c.handshakeFn = c.serverHandshake + return c } -func Client(c conn, in *elastic.RingBuffer, out *elastic.Buffer, config *Config) *Conn { - tlsconn := &Conn{ - conn: c, + +// Client returns a new TLS client side connection +// using conn as the underlying transport. +// The config cannot be nil: users must set either ServerName or +// InsecureSkipVerify in the config. +func ClientGnet(conn net.Conn, in *elastic.RingBuffer, out *elastic.Buffer, config *Config) *Conn { + c := &Conn{ + conn: conn, config: config, input: in, sendBuf: out, - outBuf: []byte{0, 3, 3, 0, 0}, isClient: true, } - return tlsconn + c.handshakeFn = c.clientHandshake + return c +} + +// Server returns a new TLS server side connection +// using conn as the underlying transport. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func Server(conn net.Conn, config *Config) *Conn { + sendBuf, _ := elastic.New(65536) + c := &Conn{ + conn: conn, + config: config, + input: new(elastic.RingBuffer), + sendBuf: sendBuf, + } + c.handshakeFn = c.serverHandshake + return c } // Client returns a new TLS client side connection // using conn as the underlying transport. // The config cannot be nil: users must set either ServerName or // InsecureSkipVerify in the config. +func Client(conn net.Conn, config *Config) *Conn { + sendBuf, _ := elastic.New(65536) + c := &Conn{ + conn: conn, + config: config, + input: new(elastic.RingBuffer), + sendBuf: sendBuf, + isClient: true, + } + c.handshakeFn = c.clientHandshake + return c +} type timeoutError struct{} @@ -77,11 +104,11 @@ func (timeoutError) Temporary() bool { return true } // form a certificate chain. On successful return, Certificate.Leaf will // be nil because the parsed form of the certificate is not retained. func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) { - certPEMBlock, err := ioutil.ReadFile(certFile) + certPEMBlock, err := os.ReadFile(certFile) if err != nil { return Certificate{}, err } - keyPEMBlock, err := ioutil.ReadFile(keyFile) + keyPEMBlock, err := os.ReadFile(keyFile) if err != nil { return Certificate{}, err } @@ -183,7 +210,7 @@ func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { } // Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates -// PKCS#1 private keys by default, while OpenSSL 1.0.0 generates PKCS#8 keys. +// PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys. // OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three. func parsePrivateKey(der []byte) (crypto.PrivateKey, error) { if key, err := x509.ParsePKCS1PrivateKey(der); err == nil { diff --git a/pkg/tls/tls_test.go b/pkg/tls/tls_test.go new file mode 100644 index 000000000..0eb33fc77 --- /dev/null +++ b/pkg/tls/tls_test.go @@ -0,0 +1,25 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "strings" +) + +var rsaCertPEM = `-----BEGIN CERTIFICATE----- +MIIB0zCCAX2gAwIBAgIJAI/M7BYjwB+uMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMTIwOTEyMjE1MjAyWhcNMTUwOTEyMjE1MjAyWjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANLJ +hPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wok/4xIA+ui35/MmNa +rtNuC+BdZ1tMuVCPFZcCAwEAAaNQME4wHQYDVR0OBBYEFJvKs8RfJaXTH08W+SGv +zQyKn0H8MB8GA1UdIwQYMBaAFJvKs8RfJaXTH08W+SGvzQyKn0H8MAwGA1UdEwQF +MAMBAf8wDQYJKoZIhvcNAQEFBQADQQBJlffJHybjDGxRMqaRmDhX0+6v02TUKZsW +r5QuVbpQhH6u+0UgcW0jp9QwpxoPTLTWGXEWBBBurxFwiCBhkQ+V +-----END CERTIFICATE----- +` + +func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } From fe87eebeae274f5883ff507c4c18783f8aa79c0d Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Fri, 20 Jan 2023 16:49:30 +0000 Subject: [PATCH 03/34] delete unsed file internal/boring/rand.go --- internal/boring/rand.go | 24 ------------------------ 1 file changed, 24 deletions(-) delete mode 100644 internal/boring/rand.go diff --git a/internal/boring/rand.go b/internal/boring/rand.go deleted file mode 100644 index 7639c0190..000000000 --- a/internal/boring/rand.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2017 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build boringcrypto && linux && (amd64 || arm64) && !android && !cmd_go_bootstrap && !msan - -package boring - -// #include "goboringcrypto.h" -import "C" -import "unsafe" - -type randReader int - -func (randReader) Read(b []byte) (int, error) { - // Note: RAND_bytes should never fail; the return value exists only for historical reasons. - // We check it even so. - if len(b) > 0 && C._goboringcrypto_RAND_bytes((*C.uint8_t)(unsafe.Pointer(&b[0])), C.size_t(len(b))) == 0 { - return 0, fail("RAND_bytes") - } - return len(b), nil -} - -const RandReader = randReader(0) From 7c5336a16e6c9d9d1d3a80306c9aa4a565e52b7c Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Fri, 20 Jan 2023 23:03:33 +0000 Subject: [PATCH 04/34] Memory optimization: add the elastic wrapper EMsgBuffer to MsgBuffer so that the tls conn not longer holds the actual buffer when the connection is idle. Other updates: 1. add defaultSize in MsgBuffer 2. fix the condition to clean up the buffer (i > blockSize to i >= blockSize) --- pkg/tls/buf.go | 132 ++++++++++++----------------------- pkg/tls/bufElastic.go | 156 ++++++++++++++++++++++++++++++++++++++++++ pkg/tls/conn.go | 18 +++-- 3 files changed, 212 insertions(+), 94 deletions(-) create mode 100644 pkg/tls/bufElastic.go diff --git a/pkg/tls/buf.go b/pkg/tls/buf.go index b9ca83d5b..0fafc4172 100644 --- a/pkg/tls/buf.go +++ b/pkg/tls/buf.go @@ -7,16 +7,18 @@ import ( type MsgBuffer struct { b []byte - l int //长度 - i int //起点位置 + l int // Total length of buffered data + i int // Position of unread buffered data } const ( - blocksize = 1024 * 5 //清理失效数据阈值 - appendsize = 4096 + blockSize = 8192 // clean up the data when i >= blocksize + appendSize = 4096 + defaultSize = 4096 ) -func NewBuffer(n int) *MsgBuffer { +// New returns a new MsgBuffer whose buffer has the given size. +func NewMsgBuffer(n int) *MsgBuffer { return &MsgBuffer{b: make([]byte, 0, n)} } @@ -25,96 +27,64 @@ func (w *MsgBuffer) Reset() { w.i = 0 } -func (w *MsgBuffer) Make(l int) []byte { - if w.i > blocksize { +// clean up the data when i >= blockSize +func (w *MsgBuffer) clean() { + if w.i >= blockSize { copy(w.b[:w.l-w.i], w.b[w.i:w.l]) w.l -= w.i w.i = 0 } - o := w.l - w.l += l - if len(w.b) < w.l { //扩容 +} + +// grow the buffer size if the size of current buffer cannot fit the new incoming data. +func (w *MsgBuffer) grow() { + if len(w.b) < w.l { if cap(w.b) < w.l { add := w.l - len(w.b) - if add > appendsize { + if add > appendSize { w.b = append(w.b, make([]byte, add)...) } else { - w.b = append(w.b, make([]byte, appendsize)...) + w.b = append(w.b, make([]byte, appendSize)...) } } w.b = w.b[:w.l] } +} + +func (w *MsgBuffer) Make(l int) []byte { + w.clean() + o := w.l + w.l += l + w.grow() return w.b[o:w.l] } func (w *MsgBuffer) Write(b []byte) (int, error) { - if w.i > blocksize { - copy(w.b[:w.l-w.i], w.b[w.i:w.l]) - w.l -= w.i - w.i = 0 - } + w.clean() l := len(b) o := w.l w.l += l - if len(w.b) < w.l { - if cap(w.b) < w.l { - add := w.l - len(w.b) - if add > appendsize { - w.b = append(w.b, make([]byte, add)...) - } else { - w.b = append(w.b, make([]byte, appendsize)...) - } - } - w.b = w.b[:w.l] - } + w.grow() copy(w.b[o:w.l], b) return l, nil } func (w *MsgBuffer) WriteString(s string) { - if w.i > blocksize { - copy(w.b[:w.l-w.i], w.b[w.i:w.l]) - w.l -= w.i - w.i = 0 - } + w.clean() x := (*[2]uintptr)(unsafe.Pointer(&s)) h := [3]uintptr{x[0], x[1], x[1]} b := *(*[]byte)(unsafe.Pointer(&h)) l := len(b) o := w.l w.l += l - if len(w.b) < w.l { //扩容 - if cap(w.b) < w.l { - add := w.l - len(w.b) - if add > appendsize { - w.b = append(w.b, make([]byte, add)...) - } else { - w.b = append(w.b, make([]byte, appendsize)...) - } - } - w.b = w.b[:w.l] - } + w.grow() copy(w.b[o:w.l], b) } func (w *MsgBuffer) WriteByte(s byte) error { - if w.i > blocksize { - copy(w.b[:w.l-w.i], w.b[w.i:w.l]) - w.l -= w.i - w.i = 0 - } + w.clean() w.l++ - if len(w.b) < w.l { - if cap(w.b) < w.l { - add := w.l - len(w.b) - if add > appendsize { - w.b = append(w.b, make([]byte, add)...) - } else { - w.b = append(w.b, make([]byte, appendsize)...) - } - } - w.b = w.b[:w.l] - } + w.grow() w.b[w.l-1] = s return nil @@ -124,7 +94,7 @@ func (w *MsgBuffer) Bytes() []byte { return w.b[w.i:w.l] } -func (w *MsgBuffer) PreBytes(n int) []byte { +func (w *MsgBuffer) Peek(n int) []byte { end := w.i + n if end > w.l { end = w.l @@ -136,17 +106,11 @@ func (w *MsgBuffer) Len() int { return w.l - w.i } -func (w *MsgBuffer) Next(l int) []byte { - o := w.i - w.i += l - if w.i > w.l { - w.i = w.l - } - return w.b[o:w.i] -} - func (w *MsgBuffer) Truncate(i int) { - w.l = w.i + i + l := w.i + i + if l < w.l { + w.l = l + } } func (w *MsgBuffer) String() string { @@ -155,24 +119,13 @@ func (w *MsgBuffer) String() string { return *(*string)(unsafe.Pointer(&b)) } -// New returns a new MsgBuffer whose buffer has the given size. -func New(size int) *MsgBuffer { - - return &MsgBuffer{ - b: make([]byte, size), - } -} - -// Shift shifts the "read" pointer. -func (r *MsgBuffer) Shift(len int) { - if len <= 0 { +// Discard skips the next n bytes by advancing the read pointer. +func (r *MsgBuffer) Discard(l int) { + if l <= 0 { return } - if len < r.Len() { - r.i += len - if r.i > r.l { - r.i = r.l - } + if l < r.Len() { + r.i += l } else { r.Reset() } @@ -207,3 +160,8 @@ func (r *MsgBuffer) ReadByte() (b byte, err error) { r.i++ return b, err } + +// IsEmpty tells if this MsgBuffer is empty. +func (b *MsgBuffer) IsEmpty() bool { + return b.i == b.l +} diff --git a/pkg/tls/bufElastic.go b/pkg/tls/bufElastic.go new file mode 100644 index 000000000..2d0d1c17b --- /dev/null +++ b/pkg/tls/bufElastic.go @@ -0,0 +1,156 @@ +package tls + +import ( + "io" + "sync" +) + +// EMsgBuffer is the elastic wrapper of EMsgBuffer. +type EMsgBuffer struct { + mb *MsgBuffer +} + +var msgBufferPool = sync.Pool{ + New: func() any { + return NewMsgBuffer(defaultSize) + }, +} + +func (b *EMsgBuffer) instance() *MsgBuffer { + if b.mb == nil { + b.mb = msgBufferPool.New().(*MsgBuffer) + } + return b.mb +} + +// Done checks and returns the internal MsgBuffer to pool. +func (b *EMsgBuffer) Done() { + if b.mb != nil { + b.mb.Reset() + msgBufferPool.Put(b.mb) + b.mb = nil + } +} + +func (b *EMsgBuffer) DoneIfEmpty() { + b.done() +} + +func (b *EMsgBuffer) done() { + if b.mb != nil && b.mb.IsEmpty() { + b.mb.Reset() + msgBufferPool.Put(b.mb) + b.mb = nil + } +} + +func (b *EMsgBuffer) Make(l int) []byte { + return b.instance().Make(l) +} + +// Write writes len(p) bytes from p to the underlying buf. +func (b *EMsgBuffer) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + return b.instance().Write(p) +} + +// WriteString writes the contents of the string s to buffer, which accepts a slice of bytes. +func (b *EMsgBuffer) WriteString(s string) { + b.instance().WriteString(s) +} + +// WriteByte writes one byte into buffer. +func (b *EMsgBuffer) WriteByte(c byte) error { + return b.instance().WriteByte(c) +} + +// Bytes returns all available read bytes. It does not move the read pointer and only copy the available data. +func (b *EMsgBuffer) Bytes() []byte { + if b.mb == nil { + return nil + } + return b.mb.Bytes() +} + +// Bytes returns first n readable bytes. It does not move the read pointer and only copy the available data. +func (b *EMsgBuffer) Peek(n int) []byte { + if b.mb == nil { + return nil + } + return b.mb.Peek(n) +} + +// Len returns the length of the underlying buffer. +func (b *EMsgBuffer) Len() int { + if b.mb == nil { + return 0 + } + return b.mb.Len() +} + +// truncate the total number of readable bytes to i +func (b *EMsgBuffer) Truncate(i int) { + if b.mb != nil { + b.mb.Truncate(i) + b.done() + } +} + +func (b *EMsgBuffer) String() string { + if b.mb == nil { + return "" + } + return b.mb.String() +} + +// Discard skips the next n bytes by advancing the read pointer. +func (b *EMsgBuffer) Discard(l int) { + if b.mb != nil { + b.mb.Discard(l) + b.done() + } +} + +// Discard skips the next n bytes by advancing the read pointer, but holding the MsgBuffer temporarily. +// Doing so can ensure one can use the data return by Peek not used by another thread. +// Therefore, thread-safe is guaranteed. +func (b *EMsgBuffer) DiscardWithoutDone(l int) { + if b.mb != nil { + b.mb.Discard(l) + } +} + +func (b *EMsgBuffer) Close() error { + if b.mb == nil { + return nil + } + return b.Close() + +} + +func (b *EMsgBuffer) Read(p []byte) (n int, err error) { + if b.mb == nil { + return 0, io.EOF + } + defer b.done() + return b.mb.Read(p) +} + +// ReadByte reads and returns the next byte from the input or ErrIsEmpty. +func (b *EMsgBuffer) ReadByte() (byte, error) { + if b.mb == nil { + return 0, io.EOF + } + defer b.done() + return b.mb.ReadByte() +} + +// IsEmpty tells if this MsgBuffer is empty. +func (b *EMsgBuffer) IsEmpty() bool { + if b.mb == nil { + return true + } + return b.mb.IsEmpty() +} diff --git a/pkg/tls/conn.go b/pkg/tls/conn.go index fb7af6969..46a2de1e1 100644 --- a/pkg/tls/conn.go +++ b/pkg/tls/conn.go @@ -95,12 +95,15 @@ type Conn struct { clientProtocol string // input/output + // By using the elastic MsgBuffer the tls conn not longer holds the actual buffer when the connection is idle. + // This can significantly optimize the memory usage, especially when the server connecting millions of clients + // where most of them are idle. in, out halfConn - rawInput MsgBuffer // raw input, starting with a record header + rawInput EMsgBuffer // raw input, starting with a record header input *elastic.RingBuffer // a buffer for decrypted records pointer to the inboundBuffer of gnet.conn - hand MsgBuffer // handshake data waiting to be read + hand EMsgBuffer // handshake data waiting to be read // buffering bool // whether records are buffered in sendBuf - sendBuf *elastic.Buffer // a buffer for records waiting to be sent also point to the outboundBuffer of gnet.conn + sendBuf *elastic.Buffer // a buffer for records waiting to be sent also point to the outboundBuffer of gnet.conn // bytesSent counts the bytes of application data sent. // packetsSent counts packets. @@ -657,7 +660,8 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { } // Process message. - c.rawInput.Shift(recordHeaderLen + n) + c.rawInput.DiscardWithoutDone(recordHeaderLen + n) + defer c.rawInput.DoneIfEmpty() data, typ, err := c.in.decrypt(hdr[:recordHeaderLen+n]) if err != nil { return c.in.setErrorLocked(c.sendAlert(err.(alert))) @@ -968,7 +972,7 @@ func (c *Conn) readHandshake() (interface{}, error) { } } - data := c.hand.PreBytes(4) + data := c.hand.Peek(4) n := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) if n > maxHandshake { c.sendAlertLocked(alertInternalError) @@ -979,7 +983,8 @@ func (c *Conn) readHandshake() (interface{}, error) { return nil, err } } - data = c.hand.Next(4 + n) + data = c.hand.Peek(4 + n) + defer c.hand.Discard(4 + n) var m handshakeMessage switch data[0] { case typeHelloRequest: @@ -1113,7 +1118,6 @@ func (c *Conn) RawData() []byte { return c.rawInput.Bytes() } - // handleRenegotiation processes a HelloRequest handshake message. func (c *Conn) handleRenegotiation() error { if c.vers == VersionTLS13 { From 3394893be60053f7a036f249a643846c8c794c3e Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Mon, 23 Jan 2023 16:35:02 +0000 Subject: [PATCH 05/34] Add kernel TLS support 1. The kernel TLS implementation is based on https://github.com/jim3ma/go.git branch: dev.ktls.1.16.3 2. Supports: TLS1.2 & TLS 1.3 3. Supported cipher suites: AES_128_GCM_SHA256 AES_256_GCM_SHA384 CHACHA20_POLY1305_SHA256 4. Server side has been tested and it works. Client side needs to be tested later 5. TODO: add sendfile(), TLS_TX_ZEROCOPY_RO (device offload), and TLS_RX_EXPECT_NO_PAD. (See https://docs.kernel.org/networking/tls.html#optional-optimizations) for details. --- eventloop.go | 48 +-- pkg/tls/conn.go | 83 ++++- pkg/tls/handshake_client.go | 7 + pkg/tls/handshake_client_tls13.go | 5 +- pkg/tls/handshake_server.go | 6 + pkg/tls/handshake_server_tls13.go | 4 + pkg/tls/ktls.go | 17 + pkg/tls/ktls_cipher_linux.go | 367 ++++++++++++++++++++++ pkg/tls/ktls_io.go | 36 +++ pkg/tls/ktls_linux.go | 497 ++++++++++++++++++++++++++++++ pkg/tls/ktls_log_debug.go | 16 + pkg/tls/ktls_log_release.go | 8 + pkg/tls/ktls_others.go | 26 ++ 13 files changed, 1093 insertions(+), 27 deletions(-) create mode 100644 pkg/tls/ktls.go create mode 100644 pkg/tls/ktls_cipher_linux.go create mode 100644 pkg/tls/ktls_io.go create mode 100644 pkg/tls/ktls_linux.go create mode 100644 pkg/tls/ktls_log_debug.go create mode 100644 pkg/tls/ktls_log_release.go create mode 100644 pkg/tls/ktls_others.go diff --git a/eventloop.go b/eventloop.go index 70dc8123f..9da326656 100644 --- a/eventloop.go +++ b/eventloop.go @@ -114,7 +114,36 @@ func (el *eventloop) open(c *conn) error { return el.handleAction(c, action) } +func (el *eventloop) readTLS(c *conn) error { + if err := c.tlsconn.ReadFrame(); err != nil { + return el.closeConn(c, os.NewSyscallError("TLS read", err)) + } + + if c.inboundBuffer.IsEmpty() { + return nil + } + + action := el.eventHandler.OnTraffic(c) + switch action { + case None: + case Close: + return el.closeConn(c, nil) + case Shutdown: + return gerrors.ErrEngineShutdown + } + return nil +} + func (el *eventloop) read(c *conn) error { + // detected whether kernel TLS RX is enabled + // This only happens after TLS handshake is completed. + // Therefore, no need to call c.tlsconn.HandshakeComplete() + // In addition, all data are copied directly from kernel to the buffer, + // meaning no need to call unix.read(c.fd, el.buffer) + if c.tlsconn != nil && c.tlsconn.IsKTLSRXEnabled() { + return el.readTLS(c) + } + n, err := unix.Read(c.fd, el.buffer) if err != nil || n == 0 { if err == unix.EAGAIN { @@ -142,24 +171,7 @@ func (el *eventloop) read(c *conn) error { return nil } } - - if err = c.tlsconn.ReadFrame(); err != nil { - return el.closeConn(c, os.NewSyscallError("TLS read", err)) - } - - if c.inboundBuffer.IsEmpty() { - return nil - } - - action := el.eventHandler.OnTraffic(c) - switch action { - case None: - case Close: - return el.closeConn(c, nil) - case Shutdown: - return gerrors.ErrEngineShutdown - } - return nil + return el.readTLS(c) } c.buffer = el.buffer[:n] diff --git a/pkg/tls/conn.go b/pkg/tls/conn.go index 46a2de1e1..20d822596 100644 --- a/pkg/tls/conn.go +++ b/pkg/tls/conn.go @@ -22,6 +22,12 @@ import ( "github.com/panjf2000/gnet/v2/pkg/buffer/elastic" ) +// Socket is a set of functions which manipulate the underlying file descriptor of a connection. +type Socket interface { + // Fd returns the underlying file descriptor. + Fd() int +} + // A Conn represents a secured connection. // It implements the net.Conn interface. type Conn struct { @@ -183,6 +189,9 @@ type halfConn struct { nextMac hash.Hash // next MAC algorithm trafficSecret []byte // current TLS 1.3 traffic secret + + key []byte // encrypt or decrypt key for kernel tls + iv []byte // encrypt or decrypt iv for kernel tls } type permanentError struct { @@ -229,8 +238,8 @@ func (hc *halfConn) changeCipherSpec() error { func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) { hc.trafficSecret = secret - key, iv := suite.trafficKey(secret) - hc.cipher = suite.aead(key, iv) + hc.key, hc.iv = suite.trafficKey(secret) + hc.cipher = suite.aead(hc.key, hc.iv) for i := range hc.seq { hc.seq[i] = 0 } @@ -270,6 +279,8 @@ func (hc *halfConn) explicitNonceLen() int { return c.BlockSize() } return 0 + case kTLSCipher: + return 0 default: panic("unknown cipher type") } @@ -600,6 +611,13 @@ func (c *Conn) readChangeCipherSpec() error { return io.EOF } +// ktlsInBufPool pools the buffers used by ktlsReadRecord. +var ktlsInBufPool = sync.Pool{ + New: func() any { + return new([maxPlaintext]byte) + }, +} + // readRecordOrCCS reads one or more TLS records from the connection and // updates the record layer state. Some invariants: // - c.in must be locked @@ -620,8 +638,40 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { } handshakeComplete := c.HandshakeComplete() - hdr := c.rawInput.Bytes() - typ := recordType(hdr[0]) + var ( + typ recordType + data []byte + // record []byte + hdr []byte + n int + vers uint16 + err error + ) + + if _, ok := c.in.cipher.(kTLSCipher); ok { + dataPtr := ktlsInBufPool.Get().(*[]byte) + data := *dataPtr + defer func() { + // You might be tempted to simplify this by just passing &outBuf to Put, + // but that would make the local copy of the outBuf slice header escape + // to the heap, causing an allocation. Instead, we keep around the + // pointer to the slice header returned by Get, which is already on the + // heap, and overwrite and return that. + *dataPtr = data + ktlsInBufPool.Put(data) + }() + if typ, n, err = ktlsReadRecord(c.conn.(Socket).Fd(), data); err != nil { + return err + } + data = data[:n] + // TODO: process the data here instead of goto processMessage + // && try to use ktlsReadRecord to write data directly into input + // rather than copy it later. + goto processMessage + } + + hdr = c.rawInput.Bytes() + typ = recordType(hdr[0]) // No valid TLS record has a type of 0x80, however SSLv2 handshakes // start with a uint16 length where the MSB is set and the first record @@ -632,8 +682,8 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received")) } - vers := uint16(hdr[1])<<8 | uint16(hdr[2]) - n := int(hdr[3])<<8 | int(hdr[4]) + vers = uint16(hdr[1])<<8 | uint16(hdr[2]) + n = int(hdr[3])<<8 | int(hdr[4]) if len(hdr) < recordHeaderLen+n { return io.EOF @@ -662,10 +712,12 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { // Process message. c.rawInput.DiscardWithoutDone(recordHeaderLen + n) defer c.rawInput.DoneIfEmpty() - data, typ, err := c.in.decrypt(hdr[:recordHeaderLen+n]) + data, typ, err = c.in.decrypt(hdr[:recordHeaderLen+n]) if err != nil { return c.in.setErrorLocked(c.sendAlert(err.(alert))) } + +processMessage: if len(data) > maxPlaintext { return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow)) } @@ -850,6 +902,8 @@ func (c *Conn) maxPayloadSizeForWrite(typ recordType) int { // The MAC is appended before padding so affects the // payload size directly. payloadBytes -= c.out.mac.Size() + case kTLSCipher: + payloadBytes -= kTLSOverhead default: panic("unknown cipher type") } @@ -900,6 +954,18 @@ var outBufPool = sync.Pool{ // writeRecordLocked writes a TLS record with the given type and payload to the // connection and updates the record layer state. func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { + if _, ok := c.out.cipher.(kTLSCipher); ok { + switch typ { + case recordTypeAlert: + return ktlsSendCtrlMessage(c.conn.(Socket).Fd(), typ, data) + case recordTypeHandshake, recordTypeChangeCipherSpec: + return ktlsSendCtrlMessage(c.conn.(Socket).Fd(), typ, data) + case recordTypeApplicationData: + return c.write(data) + default: + panic("unknown record type") + } + } outBufPtr := outBufPool.Get().(*[]byte) outBuf := *outBufPtr defer func() { @@ -1108,7 +1174,8 @@ func (c *Conn) RawWrite(data []byte) (int, error) { // Decrypt one tls record and save it in the 解析一条tls数据 func (c *Conn) ReadFrame() error { - if c.rawInput.Len() > recordHeaderLen { + _, ok := c.in.cipher.(kTLSCipher) + if c.rawInput.Len() > recordHeaderLen || ok { return c.readRecordOrCCS(false) } return io.EOF diff --git a/pkg/tls/handshake_client.go b/pkg/tls/handshake_client.go index 669ab68cc..09ffefdf7 100644 --- a/pkg/tls/handshake_client.go +++ b/pkg/tls/handshake_client.go @@ -489,6 +489,11 @@ func (hs *clientHandshakeState) handshake() error { if hs.cacheKey != "" && hs.session != nil && hs.oldsession != hs.session { c.config.ClientSessionCache.Put(hs.cacheKey, hs.session) } + // Enable kernel TLS if possible + if err := c.enableKernelTLS(c.cipherSuite, c.in.key, c.out.key, c.in.iv, c.out.iv); err != nil { + return err + } + return nil } @@ -708,6 +713,8 @@ func (hs *clientHandshakeState) establishKeys() error { serverCipher = hs.suite.aead(serverKey, serverIV) } + c.in.key, c.in.iv = serverKey, serverIV + c.out.key, c.out.iv = clientKey, clientIV c.in.prepareCipherSpec(c.vers, serverCipher, serverHash) c.out.prepareCipherSpec(c.vers, clientCipher, clientHash) return nil diff --git a/pkg/tls/handshake_client_tls13.go b/pkg/tls/handshake_client_tls13.go index 2b272c1c4..c9b11c819 100644 --- a/pkg/tls/handshake_client_tls13.go +++ b/pkg/tls/handshake_client_tls13.go @@ -105,7 +105,10 @@ func (hs *clientHandshakeStateTLS13) handshake() error { } c.handshakeStatus = 255 - + // Enable kernel TLS if possible + if err := c.enableKernelTLS(c.cipherSuite, c.in.key, c.out.key, c.in.iv, c.out.iv); err != nil { + return err + } return nil } diff --git a/pkg/tls/handshake_server.go b/pkg/tls/handshake_server.go index f4506ea63..361a9f6e9 100644 --- a/pkg/tls/handshake_server.go +++ b/pkg/tls/handshake_server.go @@ -159,6 +159,10 @@ func (hs *serverHandshakeState) handshake() error { c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random) c.handshakeStatus = 255 + // Enable kernel TLS if possible + if err := c.enableKernelTLS(c.cipherSuite, c.in.key, c.out.key, c.in.iv, c.out.iv); err != nil { + return err + } return nil } @@ -739,6 +743,8 @@ func (hs *serverHandshakeState) establishKeys() error { serverCipher = hs.suite.aead(serverKey, serverIV) } + c.in.key, c.in.iv = clientKey, clientIV + c.out.key, c.out.iv = serverKey, serverIV c.in.prepareCipherSpec(c.vers, clientCipher, clientHash) c.out.prepareCipherSpec(c.vers, serverCipher, serverHash) diff --git a/pkg/tls/handshake_server_tls13.go b/pkg/tls/handshake_server_tls13.go index ee527036e..860518ecb 100644 --- a/pkg/tls/handshake_server_tls13.go +++ b/pkg/tls/handshake_server_tls13.go @@ -86,6 +86,10 @@ func (hs *serverHandshakeStateTLS13) handshake() error { return err } c.handshakeStatus = 255 + // Enable kernel TLS if possible + if err := c.enableKernelTLS(c.cipherSuite, c.in.key, c.out.key, c.in.iv, c.out.iv); err != nil { + return err + } } return nil diff --git a/pkg/tls/ktls.go b/pkg/tls/ktls.go new file mode 100644 index 000000000..69e4ba647 --- /dev/null +++ b/pkg/tls/ktls.go @@ -0,0 +1,17 @@ +package tls + +import ( + "os" + "strings" +) + +var kTLSEnabled bool + +// kTLSCipher is a placeholder to tell the record layer to skip wrapping. +type kTLSCipher struct{} + +func init() { + kTLSEnabled = strings.ToLower(os.Getenv("GOKTLS")) == "true" || + strings.ToLower(os.Getenv("GOKTLS")) == "on" || + os.Getenv("GOKTLS") == "1" +} diff --git a/pkg/tls/ktls_cipher_linux.go b/pkg/tls/ktls_cipher_linux.go new file mode 100644 index 000000000..2517177c2 --- /dev/null +++ b/pkg/tls/ktls_cipher_linux.go @@ -0,0 +1,367 @@ +//go:build linux +// +build linux + +package tls + +import ( + "fmt" + "syscall" + "unsafe" +) + +const ( + kTLS_CIPHER_AES_GCM_128 = 51 + kTLS_CIPHER_AES_GCM_128_IV_SIZE = 8 + kTLS_CIPHER_AES_GCM_128_KEY_SIZE = 16 + kTLS_CIPHER_AES_GCM_128_SALT_SIZE = 4 + kTLS_CIPHER_AES_GCM_128_TAG_SIZE = 16 + kTLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE = 8 + + kTLS_CIPHER_AES_GCM_256 = 52 + kTLS_CIPHER_AES_GCM_256_IV_SIZE = 8 + kTLS_CIPHER_AES_GCM_256_KEY_SIZE = 32 + kTLS_CIPHER_AES_GCM_256_SALT_SIZE = 4 + kTLS_CIPHER_AES_GCM_256_TAG_SIZE = 16 + kTLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE = 8 + + // AES_CCM_128 is not used as it has not been implemented in golang + kTLS_CIPHER_AES_CCM_128 = 53 + kTLS_CIPHER_AES_CCM_128_IV_SIZE = 8 + kTLS_CIPHER_AES_CCM_128_KEY_SIZE = 16 + kTLS_CIPHER_AES_CCM_128_SALT_SIZE = 4 + kTLS_CIPHER_AES_CCM_128_TAG_SIZE = 16 + kTLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE = 8 + + kTLS_CIPHER_CHACHA20_POLY1305 = 54 + kTLS_CIPHER_CHACHA20_POLY1305_IV_SIZE = 12 + kTLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE = 32 + kTLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE = 0 + kTLS_CIPHER_CHACHA20_POLY1305_TAG_SIZE = 16 + kTLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE = 8 +) + +type kTLSCryptoInfo struct { + version uint16 + cipherType uint16 +} + +type kTLSCryptoInfoAESGCM128 struct { + info kTLSCryptoInfo + iv [kTLS_CIPHER_AES_GCM_128_IV_SIZE]byte + key [kTLS_CIPHER_AES_GCM_128_KEY_SIZE]byte + salt [kTLS_CIPHER_AES_GCM_128_SALT_SIZE]byte + recSeq [kTLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE]byte +} + +type kTLSCryptoInfoAESGCM256 struct { + info kTLSCryptoInfo + iv [kTLS_CIPHER_AES_GCM_256_IV_SIZE]byte + key [kTLS_CIPHER_AES_GCM_256_KEY_SIZE]byte + salt [kTLS_CIPHER_AES_GCM_256_SALT_SIZE]byte + recSeq [kTLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE]byte +} + +// AES_CCM_128 is not used as it has not been implemented in golang +type kTLSCryptoInfoAESCCM128 struct { + info kTLSCryptoInfo + iv [kTLS_CIPHER_AES_CCM_128_IV_SIZE]byte + key [kTLS_CIPHER_AES_CCM_128_KEY_SIZE]byte + salt [kTLS_CIPHER_AES_CCM_128_SALT_SIZE]byte + recSeq [kTLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE]byte +} + +type kTLSCryptoInfoCHACHA20POLY1305 struct { + info kTLSCryptoInfo + iv [kTLS_CIPHER_CHACHA20_POLY1305_IV_SIZE]byte + key [kTLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE]byte + salt [kTLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE]byte + recSeq [kTLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE]byte +} + +const ( + kTLSCryptoInfoSize_AES_GCM_128 = 2 + 2 + kTLS_CIPHER_AES_GCM_128_IV_SIZE + kTLS_CIPHER_AES_GCM_128_KEY_SIZE + + kTLS_CIPHER_AES_GCM_128_SALT_SIZE + kTLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE + + kTLSCryptoInfoSize_AES_GCM_256 = 2 + 2 + kTLS_CIPHER_AES_GCM_256_IV_SIZE + kTLS_CIPHER_AES_GCM_256_KEY_SIZE + + kTLS_CIPHER_AES_GCM_256_SALT_SIZE + kTLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE + + kTLSCryptoInfoSize_AES_CCM_128 = 2 + 2 + kTLS_CIPHER_AES_CCM_128_IV_SIZE + kTLS_CIPHER_AES_CCM_128_KEY_SIZE + + kTLS_CIPHER_AES_CCM_128_SALT_SIZE + kTLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE + + kTLSCryptoInfoSize_CHACHA20_POLY1305 = 2 + 2 + kTLS_CIPHER_CHACHA20_POLY1305_IV_SIZE + kTLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE + + kTLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE + kTLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE +) + +func ktlsEnableAES( + c *Conn, + version uint16, + enableFunc func(fd int, version uint16, opt int, skip bool, key, iv, seq []byte) error, + keyLen int, + inKey, outKey, inIV, outIV []byte) error { + var ulpEnabled bool + + // Try to enable Kernel TLS TX + if !kTLSSupportTX { + return nil + } + if len(outKey) == keyLen { + if sock, ok := c.conn.(Socket); ok { + if err := enableFunc(sock.Fd(), version, TLS_TX, ulpEnabled, outKey, outIV[:], c.out.seq[:]); err != nil { + Debugln("kTLS: TLS_TX error enabling:", err) + return err + } + ulpEnabled = true + Debugln("kTLS: TLS_TX enabled") + c.out.cipher = kTLSCipher{} + } else { + Debugln("kTLS: TLS_TX unsupported connection type") + } + } else { + Debugln("kTLS: TLS_TX unsupported key length") + } + + // Try to enable Kernel TLS RX + if !kTLSSupportRX { + return nil + } + if len(inKey) == keyLen { + if sock, ok := c.conn.(Socket); ok { + if err := enableFunc(sock.Fd(), version, TLS_RX, ulpEnabled, inKey, inIV[:], c.in.seq[:]); err != nil { + Debugln("kTLS: TLS_RX error enabling:", err) + return err + } + Debugln("kTLS: TLS_RX enabled") + c.in.cipher = kTLSCipher{} + } else { + Debugln("kTLS: TLS_RX unsupported connection type") + } + } else { + Debugln("kTLS: TLS_TX unsupported key length") + } + + return nil +} + +func ktlsEnableCHACHA20(c *Conn, version uint16, inKey, outKey, inIV, outIV []byte) error { + var ulpEnabled bool + + // Try to enable Kernel TLS TX + if !kTLSSupportTX { + return nil + } + if sock, ok := c.conn.(Socket); ok { + err := ktlsEnableCHACHA20POLY1305(sock.Fd(), version, TLS_TX, ulpEnabled, outKey, outIV, c.out.seq[:]) + if err != nil { + Debugln("kTLS: TLS_TX error enabling:", err) + return err + } + ulpEnabled = true + Debugln("kTLS: TLS_TX enabled") + c.out.cipher = kTLSCipher{} + } else { + Debugln("kTLS: TLS_TX unsupported connection type") + } + + // Try to enable Kernel TLS RX + if !kTLSSupportRX { + return nil + } + if sock, ok := c.conn.(Socket); ok { + err := ktlsEnableCHACHA20POLY1305(sock.Fd(), version, TLS_RX, ulpEnabled, inKey[:], inIV[:], c.in.seq[:]) + if err != nil { + Debugln("kTLS: TLS_RX error enabling:", err) + return err + } + ulpEnabled = true + Debugln("kTLS: TLS_RX enabled") + c.in.cipher = kTLSCipher{} + } else { + Debugln("kTLS: TLS_RX unsupported connection type") + } + + return nil +} + +func ktlsEnableAES128GCM(fd int, version uint16, opt int, skip bool, key, iv, seq []byte) (err error) { + if len(key) != kTLS_CIPHER_AES_GCM_128_KEY_SIZE { + return fmt.Errorf("kTLS: wrong key length, desired: %d, actual: %d", + kTLS_CIPHER_AES_GCM_128_KEY_SIZE, len(key)) + } + if version == VersionTLS12 { + // The nounce of TLS 1.2 only has 4 bytes. So, compare with kTLS_CIPHER_AES_GCM_128_SALT_SIZE only + if len(iv) != kTLS_CIPHER_AES_GCM_128_SALT_SIZE { + return fmt.Errorf("kTLS: wrong iv length, desired: %d, actual: %d", + kTLS_CIPHER_AES_GCM_128_SALT_SIZE, len(iv)) + } + if len(seq) != kTLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE { + return fmt.Errorf("kTLS: wrong seq length, desired: %d, actual: %d", + kTLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE, len(seq)) + } + } else { + // The nounce of TLS 1.3 only has 12 bytes. So, compare with + // kTLS_CIPHER_AES_GCM_128_SALT_SIZE + kTLS_CIPHER_AES_GCM_128_IV_SIZE + if len(iv) != kTLS_CIPHER_AES_GCM_128_SALT_SIZE+kTLS_CIPHER_AES_GCM_128_IV_SIZE { + return fmt.Errorf("kTLS: wrong iv length, desired: %d, actual: %d", + kTLS_CIPHER_AES_GCM_128_SALT_SIZE+kTLS_CIPHER_AES_GCM_128_IV_SIZE, len(iv)) + } + if len(seq) != kTLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE { + return fmt.Errorf("kTLS: wrong seq length, desired: %d, actual: %d", + kTLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE, len(seq)) + } + } + + cryptoInfo := kTLSCryptoInfoAESGCM128{ + info: kTLSCryptoInfo{ + version: version, + cipherType: kTLS_CIPHER_AES_GCM_128, + }, + } + + Debugf("\nkey: %x\niv: %x\nseq: %x", key, iv, seq) + copy(cryptoInfo.key[:], key) + copy(cryptoInfo.salt[:], iv[:kTLS_CIPHER_AES_GCM_128_SALT_SIZE]) + // TODO https://github.com/FiloSottile/go/blob/filippo%2FkTLS/src/crypto/tls/ktls.go#L73 + // the PoC of FiloSottile here is copy(cryptoInfo.iv[:], seq) + // For TLS 1.2, its IV is 0, whereas TLS 1.3 uses the rest of 8 bytes + copy(cryptoInfo.iv[:], iv[kTLS_CIPHER_AES_GCM_128_SALT_SIZE:]) + copy(cryptoInfo.recSeq[:], seq) + + // Assert padding isn't introduced by alignment requirements. + if unsafe.Sizeof(cryptoInfo) != kTLSCryptoInfoSize_AES_GCM_128 { + return fmt.Errorf("kTLS: wrong cryptoInfo size, desired: %d, actual: %d", + kTLSCryptoInfoSize_AES_GCM_128, unsafe.Sizeof(cryptoInfo)) + } + + if !skip { + err = syscall.SetsockoptString(int(fd), syscall.SOL_TCP, TCP_ULP, "tls") + if err != nil { + Debugln("kTLS: setsockopt(SOL_TCP, TCP_ULP) failed:", err) + return + } + } + err = syscall.SetsockoptString(int(fd), SOL_TLS, opt, + string((*[kTLSCryptoInfoSize_AES_GCM_128]byte)(unsafe.Pointer(&cryptoInfo))[:])) + if err != nil { + Debugf("kTLS: setsockopt(SOL_TLS, %d) failed: %s", opt, err) + return + } + + return err +} + +func ktlsEnableAES256GCM(fd int, version uint16, opt int, skip bool, key, iv, seq []byte) (err error) { + if len(key) != kTLS_CIPHER_AES_GCM_256_KEY_SIZE { + return fmt.Errorf("kTLS: wrong key length, desired: %d, actual: %d", + kTLS_CIPHER_AES_GCM_256_KEY_SIZE, len(key)) + } + if version == VersionTLS12 { + // The nounce of TLS 1.2 only has 4 bytes. So, compare with kTLS_CIPHER_AES_GCM_256_SALT_SIZE only + if len(iv) != kTLS_CIPHER_AES_GCM_256_SALT_SIZE { + return fmt.Errorf("kTLS: wrong iv length, desired: %d, actual: %d", + kTLS_CIPHER_AES_GCM_256_SALT_SIZE, len(iv)) + } + if len(seq) != kTLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE { + return fmt.Errorf("kTLS: wrong seq length, desired: %d, actual: %d", + kTLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE, len(seq)) + } + } else { + // The nounce of TLS 1.3 only has 12 bytes. So, compare with + // kTLS_CIPHER_AES_GCM_256_SALT_SIZE + kTLS_CIPHER_AES_GCM_256_IV_SIZE + if len(iv) != kTLS_CIPHER_AES_GCM_256_SALT_SIZE+kTLS_CIPHER_AES_GCM_256_IV_SIZE { + return fmt.Errorf("kTLS: wrong iv length, desired: %d, actual: %d", + kTLS_CIPHER_AES_GCM_256_SALT_SIZE+kTLS_CIPHER_AES_GCM_256_IV_SIZE, len(iv)) + } + if len(seq) != kTLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE { + return fmt.Errorf("kTLS: wrong seq length, desired: %d, actual: %d", + kTLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE, len(seq)) + } + } + + cryptoInfo := kTLSCryptoInfoAESGCM256{ + info: kTLSCryptoInfo{ + version: version, + cipherType: kTLS_CIPHER_AES_GCM_256, + }, + } + + Debugf("key: %x\niv: %x\n seq: %x", key, iv, seq) + copy(cryptoInfo.key[:], key) + copy(cryptoInfo.salt[:], iv[:kTLS_CIPHER_AES_GCM_256_SALT_SIZE]) + // TODO https://github.com/FiloSottile/go/blob/filippo%2FkTLS/src/crypto/tls/ktls.go#L73 + // the PoC of FiloSottile here is copy(cryptoInfo.iv[:], seq) + // For TLS 1.2, its IV is 0, whereas TLS 1.3 uses the rest of 8 bytes + copy(cryptoInfo.iv[:], iv[kTLS_CIPHER_AES_GCM_256_SALT_SIZE:]) + copy(cryptoInfo.recSeq[:], seq) + + // Assert padding isn't introduced by alignment requirements. + if unsafe.Sizeof(cryptoInfo) != kTLSCryptoInfoSize_AES_GCM_256 { + return fmt.Errorf("kTLS: wrong cryptoInfo size, desired: %d, actual: %d", + kTLSCryptoInfoSize_AES_GCM_256, unsafe.Sizeof(cryptoInfo)) + } + + if !skip { + err = syscall.SetsockoptString(int(fd), syscall.SOL_TCP, TCP_ULP, "tls") + if err != nil { + Debugln("kTLS: setsockopt(SOL_TCP, TCP_ULP) failed:", err) + return + } + } + err = syscall.SetsockoptString(int(fd), SOL_TLS, opt, + string((*[kTLSCryptoInfoSize_AES_GCM_256]byte)(unsafe.Pointer(&cryptoInfo))[:])) + if err != nil { + Debugf("kTLS: setsockopt(SOL_TLS, %d) failed: %s", opt, err) + return + } + + return err +} + +func ktlsEnableCHACHA20POLY1305(fd int, version uint16, opt int, skip bool, key, iv, seq []byte) (err error) { + if len(key) != kTLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE { + return fmt.Errorf("kTLS: wrong key length, desired: %d, actual: %d", + kTLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE, len(key)) + } + if len(iv) != kTLS_CIPHER_CHACHA20_POLY1305_IV_SIZE { + return fmt.Errorf("kTLS: wrong iv length, desired: %d, actual: %d", + kTLS_CIPHER_CHACHA20_POLY1305_IV_SIZE, len(iv)) + } + if len(seq) != kTLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE { + return fmt.Errorf("kTLS: wrong seq length, desired: %d, actual: %d", + kTLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE, len(seq)) + } + + cryptoInfo := kTLSCryptoInfoCHACHA20POLY1305{ + info: kTLSCryptoInfo{ + version: version, + cipherType: kTLS_CIPHER_CHACHA20_POLY1305, + }, + } + + Debugf("\nkey: %x\niv: %x\nseq: %x", key, iv, seq) + copy(cryptoInfo.key[:], key) + copy(cryptoInfo.iv[:], iv) + // the salt of CHACHA20POLY1305 is 0 bytes. So, no need to copy + copy(cryptoInfo.recSeq[:], seq) + + // Assert padding isn't introduced by alignment requirements. + if unsafe.Sizeof(cryptoInfo) != kTLSCryptoInfoSize_CHACHA20_POLY1305 { + return fmt.Errorf("kTLS: wrong cryptoInfo size, desired: %d, actual: %d", + kTLSCryptoInfoSize_CHACHA20_POLY1305, unsafe.Sizeof(cryptoInfo)) + } + + if !skip { + err = syscall.SetsockoptString(int(fd), syscall.SOL_TCP, TCP_ULP, "tls") + if err != nil { + Debugln("kTLS: setsockopt(SOL_TCP, TCP_ULP) failed:", err) + return + } + } + err = syscall.SetsockoptString(int(fd), SOL_TLS, opt, + string((*[kTLSCryptoInfoSize_CHACHA20_POLY1305]byte)(unsafe.Pointer(&cryptoInfo))[:])) + if err != nil { + Debugf("kTLS: setsockopt(SOL_TLS, %d) failed: %s", opt, err) + return + } + + return err +} diff --git a/pkg/tls/ktls_io.go b/pkg/tls/ktls_io.go new file mode 100644 index 000000000..b460c9296 --- /dev/null +++ b/pkg/tls/ktls_io.go @@ -0,0 +1,36 @@ +package tls + +import "io" + +// LimitWriter is a copy of the standard library ioutils.LimitReader, +// applied to the writer interface. +// LimitWriter returns a Writer that writes to w +// but stops with EOF after n bytes. +// The underlying implementation is a *LimitedWriter. +func LimitWriter(w io.Writer, n int64) io.Writer { return &LimitedWriter{w, n} } + +// A LimitedWriter writes to W but limits the amount of +// data returned to just N bytes. Each call to Write +// updates N to reflect the new amount remaining. +// Write returns EOF when N <= 0 or when the underlying W returns EOF. +type LimitedWriter struct { + W io.Writer // underlying writer + N int64 // max bytes remaining +} + +func (l *LimitedWriter) Write(p []byte) (n int, err error) { + if l.N <= 0 { + return 0, io.ErrShortWrite + } + truncated := false + if int64(len(p)) > l.N { + p = p[0:l.N] + truncated = true + } + n, err = l.W.Write(p) + l.N -= int64(n) + if err == nil && truncated { + err = io.ErrShortWrite + } + return +} \ No newline at end of file diff --git a/pkg/tls/ktls_linux.go b/pkg/tls/ktls_linux.go new file mode 100644 index 000000000..61c0a758a --- /dev/null +++ b/pkg/tls/ktls_linux.go @@ -0,0 +1,497 @@ +//go:build linux +// +build linux + +package tls + +import ( + "fmt" + "io" + "net" + "os" + "strconv" + "strings" + "syscall" + "unsafe" + + "golang.org/x/sys/unix" +) + +const ( + TCP_ULP = 31 + SOL_TLS = 282 + + TLS_TX = 1 + TLS_RX = 2 + + TLS_SET_RECORD_TYPE = 1 + TLS_GET_RECORD_TYPE = 2 + + kTLSOverhead = 16 +) + +var ( + kTLSSupport bool + + // kTLSSupportTX is true when kTLSSupport is true + kTLSSupportTX bool + kTLSSupportRX bool + + // kTLSSupportAESGCM128 is true when kTLSSupport is true + kTLSSupportAESGCM128 bool + kTLSSupportAESGCM256 bool + kTLSSupportCHACHA20POLY1305 bool + + kTLSSupportTLS13 bool +) + +func init() { + // when kernel tls module enabled, /sys/module/tls is available + _, err := os.Stat("/sys/module/tls") + if err != nil { + Debugln("kTLS: kernel tls module not enabled") + return + } + kTLSSupport = true && kTLSEnabled + Debugf("kTLS Enabled Status: %v\n", kTLSSupport) + + var uname syscall.Utsname + if err := syscall.Uname(&uname); err != nil { + Debugf("kTLS: call uname failed %v", err) + return + } + + var buf [65]byte + for i, b := range uname.Release { + buf[i] = byte(b) + } + release := string(buf[:]) + if i := strings.Index(release, "\x00"); i != -1 { + release = release[:i] + } + majorRelease := release[:strings.Index(release, ".")] + minorRelease := strings.TrimLeft(release, majorRelease+".") + minorRelease = minorRelease[:strings.Index(minorRelease, ".")] + major, err := strconv.Atoi(majorRelease) + if err != nil { + Debugf("kTLS: parse major release failed %v", err) + return + } + minor, err := strconv.Atoi(minorRelease) + if err != nil { + Debugf("kTLS: parse minor release failed %v", err) + return + } + + if (major == 4 && minor >= 13) || major > 4 { + kTLSSupportTX = true + kTLSSupportAESGCM128 = true + } + + if (major == 4 && minor >= 17) || major > 4 { + kTLSSupportRX = true + } + + if (major == 5 && minor >= 1) || major > 5 { + kTLSSupportAESGCM256 = true + kTLSSupportTLS13 = true + } + + if (major == 5 && minor >= 11) || major > 5 { + kTLSSupportCHACHA20POLY1305 = true + } +} + +func (c *Conn) ReadFrom(r io.Reader) (n int64, err error) { + if err := c.Handshake(); err != nil { + return 0, err + } + return io.Copy(c.conn, r) +} + +const maxBufferSize int64 = 4 * 1024 * 1024 + +func (c *Conn) writeToFile(f *os.File, remain int64) (written int64, err error, handled bool) { + if remain <= 0 { + return 0, nil, false + } + offset, err := f.Seek(0, io.SeekCurrent) + if err != nil { + return 0, nil, false + } + fi, err := f.Stat() + if err != nil { + return 0, nil, false + } + if offset+remain > fi.Size() { + err = f.Truncate(offset + remain) + if err != nil { + Debugf("file truncate error: %s", err) + return 0, nil, false + } + } + + // mmap must align on a page boundary + // mmap from 0, use data from offset + bytes, err := syscall.Mmap(int(f.Fd()), 0, int(offset+remain), + syscall.PROT_WRITE, syscall.MAP_SHARED) + if err != nil { + return 0, nil, false + } + defer syscall.Munmap(bytes) + + bytes = bytes[offset : offset+remain] + var ( + start = int64(0) + end = maxBufferSize + ) + + for { + if end > remain { + end = remain + } + //now := time.Now() + n, err := c.Read(bytes[start:end]) + if err != nil { + return start + int64(n), err, true + } + //log.Printf("read %d bytes, cost %dus", n, time.Since(now).Microseconds()) + start += int64(n) + if start >= remain { + break + } + + end += int64(n) + } + return remain, nil, true +} + +var maxSpliceSize int64 = 4 << 20 + +func (c *Conn) spliceToFile(f *os.File, remain int64) (written int64, err error, handled bool) { + tcpConn, ok := c.conn.(*net.TCPConn) + if !ok { + return 0, nil, false + } + sc, err := tcpConn.SyscallConn() + if err != nil { + return 0, nil, false + } + fsc, err := f.SyscallConn() + if err != nil { + return 0, nil, false + } + + var pipes [2]int + if err := unix.Pipe(pipes[:]); err != nil { + return 0, nil, false + } + + prfd, pwfd := pipes[0], pipes[1] + defer destroyTempPipe(prfd, pwfd) + + var ( + n = maxSpliceSize + m int64 + ) + + rerr := sc.Read(func(rfd uintptr) (done bool) { + for { + n = maxSpliceSize + if n > remain { + n = remain + } + // move tcp data to pipe + // FIXME should not use unix.SPLICE_F_NONBLOCK, when use this flag, ktls will not advance socket buffer + // refer: https://github.com/torvalds/linux/blob/v5.12/net/tls/tls_sw.c#L2021 + n, err = unix.Splice(int(rfd), nil, pwfd, nil, int(n), unix.SPLICE_F_MORE) + remain -= n + written += n + if err == syscall.EAGAIN { + // return false to wait data from connection + err = nil + return false + } + + if err != nil { + break + } + + // move pipe data to file + werr := fsc.Write(func(wfd uintptr) (done bool) { + bump: + m, err = unix.Splice(prfd, nil, int(wfd), nil, int(n), + unix.SPLICE_F_MOVE|unix.SPLICE_F_MORE|unix.SPLICE_F_NONBLOCK) + if err != nil { + return true + } + if m < n { + n -= m + goto bump + } + return true + }) + if err == nil { + err = werr + } + if err != nil || remain <= 0 { + break + } + } + return true + }) + if err == nil { + err = rerr + } + return written, err, true +} + +// destroyTempPipe destroys a temporary pipe. +func destroyTempPipe(prfd, pwfd int) error { + err := syscall.Close(prfd) + err1 := syscall.Close(pwfd) + if err == nil { + return err1 + } + return err +} + +func (c *Conn) WriteTo(w io.Writer) (n int64, err error) { + if err := c.Handshake(); err != nil { + return 0, err + } + + if lw, ok := w.(*LimitedWriter); ok { + if f, ok := lw.W.(*os.File); ok { + n, err, handled := c.spliceToFile(f, lw.N) + if handled { + return n, err + } + } + } + + // FIXME read at least one record for io.EOF and so on ? + //if conn, ok := w.(*net.TCPConn); ok { + // buf := make([]byte, 16*1024) + // n, err := ktlsReadRecord(conn, buf) + // if err != nil { + // wn, _ := w.Write(buf[:n]) + // return int64(wn), err + // } + // wn, err := w.Write(buf[:n]) + // if err != nil { + // return int64(wn), err + // } + //} + return io.Copy(w, c.conn) +} + +func (c *Conn) IsKTLSTXEnabled() bool { + _, ok := c.out.cipher.(kTLSCipher) + return ok +} + +func (c *Conn) IsKTLSRXEnabled() bool { + _, ok := c.in.cipher.(kTLSCipher) + return ok +} + +func (c *Conn) enableKernelTLS(cipherSuiteID uint16, inKey, outKey, inIV, outIV []byte) error { + if !kTLSSupport { + return nil + } + switch cipherSuiteID { + // Kernel TLS 1.2 + case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_RSA_WITH_AES_128_GCM_SHA256: + if !kTLSSupportAESGCM128 { + return nil + } + Debugln("try to enable kernel tls AES_128_GCM") + return ktlsEnableAES(c, VersionTLS12, ktlsEnableAES128GCM, 16, inKey, outKey, inIV, outIV) + case TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_RSA_WITH_AES_256_GCM_SHA384: + if !kTLSSupportAESGCM256 { + return nil + } + Debugln("try to enable kernel tls AES_256_GCM") + return ktlsEnableAES(c, VersionTLS12, ktlsEnableAES256GCM, 32, inKey, outKey, inIV, outIV) + case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256: + if !kTLSSupportCHACHA20POLY1305 { + return nil + } + Debugln("try to enable kernel tls CHACHA20_POLY1305") + return ktlsEnableCHACHA20(c, VersionTLS12, inKey, outKey, inIV, outIV) + + // Kernel TLS 1.3 + case TLS_AES_128_GCM_SHA256: + if !kTLSSupportAESGCM128 { + return nil + } + Debugln("try to enable kernel tls AES_128_GCM for tls 1.3") + return ktlsEnableAES(c, VersionTLS13, ktlsEnableAES128GCM, 16, inKey, outKey, inIV, outIV) + case TLS_AES_256_GCM_SHA384: + if !kTLSSupportAESGCM256 { + return nil + } + Debugln("try to enable kernel tls AES_256_GCM tls 1.3") + return ktlsEnableAES(c, VersionTLS13, ktlsEnableAES256GCM, 32, inKey, outKey, inIV, outIV) + case TLS_CHACHA20_POLY1305_SHA256: + if !kTLSSupportCHACHA20POLY1305 { + return nil + } + Debugln("try to enable kernel tls CHACHA20_POLY1305 for tls 1.3") + return ktlsEnableCHACHA20(c, VersionTLS13, inKey, outKey, inIV, outIV) + } + return nil +} + +func ktlsReadRecord(fd int, b []byte) (recordType, int, error) { + // cmsg for record type + buffer := make([]byte, syscall.CmsgSpace(1)) + cmsg := (*syscall.Cmsghdr)(unsafe.Pointer(&buffer[0])) + cmsg.SetLen(syscall.CmsgLen(1)) + + var iov syscall.Iovec + iov.Base = &b[0] + iov.SetLen(len(b)) + + var msg syscall.Msghdr + msg.Control = &buffer[0] + msg.Controllen = cmsg.Len + msg.Iov = &iov + msg.Iovlen = 1 + + var n int + flags := 0 + n, err := recvmsg(uintptr(fd), &msg, flags) + if err == syscall.EAGAIN { + // data is not ready, goroutine will be parked + return 0, n, err + } + // n should not be zero when err == nil + if err == nil && n == 0 { + err = io.EOF + } + + if err != nil { + Debugln("kTLS: recvmsg failed:", err) + // fix bufio panic due to n == -1 + if n == -1 { + n = 0 + } + return 0, n, err + } + + if n < 0 { + return 0, 0, fmt.Errorf("unknown size received: %d", n) + } else if n == 0 { + return 0, 0, nil + } + + if cmsg.Level != SOL_TLS { + Debugf("kTLS: unsupported cmsg level: %d", cmsg.Level) + return 0, 0, fmt.Errorf("unsupported cmsg level: %d", cmsg.Level) + } + if cmsg.Type != TLS_GET_RECORD_TYPE { + Debugf("kTLS: unsupported cmsg type: %d", cmsg.Type) + return 0, 0, fmt.Errorf("unsupported cmsg type: %d", cmsg.Type) + } + typ := recordType(buffer[syscall.SizeofCmsghdr]) + Debugf("kTLS: recvmsg, type: %d, payload len: %d", typ, n) + return typ, n, nil +} + +func ktlsReadDataFromRecord(fd int, b []byte) (int, error) { + typ, n, err := ktlsReadRecord(fd, b) + if err != nil { + return n, err + } + switch typ { + case recordTypeAlert: + if n < 2 { + return 0, fmt.Errorf("ktls alert payload too short") + } + if alert(b[1]) == alertCloseNotify { + return 0, io.EOF + } + return 0, fmt.Errorf("unsupported ktls alert type: %d", b[0]) + case recordTypeApplicationData: + return n, nil + default: + return 0, fmt.Errorf("unsupported ktls record type: %d", typ) + } +} + +func recvmsg(fd uintptr, msg *syscall.Msghdr, flags int) (n int, err error) { + r0, _, e1 := syscall.Syscall(syscall.SYS_RECVMSG, fd, uintptr(unsafe.Pointer(msg)), uintptr(flags)) + n = int(r0) + if e1 != 0 { + err = errnoErr(e1) + } + return +} + +func sendmsg(fd uintptr, msg *syscall.Msghdr, flags int) (n int, err error) { + r0, _, e1 := syscall.Syscall(syscall.SYS_SENDMSG, fd, uintptr(unsafe.Pointer(msg)), uintptr(flags)) + n = int(r0) + if e1 != 0 { + err = errnoErr(e1) + } + return +} + +// Do the interface allocations only once for common +// Errno values. +var ( + errEAGAIN error = syscall.EAGAIN + errEINVAL error = syscall.EINVAL + errENOENT error = syscall.ENOENT +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return nil + case syscall.EAGAIN: + return errEAGAIN + case syscall.EINVAL: + return errEINVAL + case syscall.ENOENT: + return errENOENT + } + return e +} + +func ktlsSendCtrlMessage(fd int, typ recordType, b []byte) (int, error) { + // cmsg for record type + buffer := make([]byte, syscall.CmsgSpace(1)) + cmsg := (*syscall.Cmsghdr)(unsafe.Pointer(&buffer[0])) + cmsg.SetLen(syscall.CmsgLen(1)) + buffer[syscall.SizeofCmsghdr] = byte(typ) + cmsg.Level = SOL_TLS + cmsg.Type = TLS_SET_RECORD_TYPE + + var iov syscall.Iovec + iov.Base = &b[0] + iov.SetLen(len(b)) + + var msg syscall.Msghdr + msg.Control = &buffer[0] + msg.Controllen = cmsg.Len + msg.Iov = &iov + msg.Iovlen = 1 + + var n int + flags := 0 + n, err := sendmsg(uintptr(fd), &msg, flags) + if err == syscall.EAGAIN { + // data is not ready, goroutine will be parked + return n, err + } + if err != nil { + Debugln("kTLS: sendmsg failed:", err) + } + + Debugf("kTLS: sendmsg, type: %d, payload len: %d", typ, len(b)) + return n, err +} diff --git a/pkg/tls/ktls_log_debug.go b/pkg/tls/ktls_log_debug.go new file mode 100644 index 000000000..da1352c29 --- /dev/null +++ b/pkg/tls/ktls_log_debug.go @@ -0,0 +1,16 @@ +//go:build debug +package tls + +import ( + "log" +) + +const Dev = true + +func Debugln(a ...interface{}) { + log.Println(a...) +} + +func Debugf(format string, a ...interface{}) { + log.Printf(format, a...) +} \ No newline at end of file diff --git a/pkg/tls/ktls_log_release.go b/pkg/tls/ktls_log_release.go new file mode 100644 index 000000000..f5f5149cb --- /dev/null +++ b/pkg/tls/ktls_log_release.go @@ -0,0 +1,8 @@ +//go:build !debug +package tls + +const Dev = false + +func Debugln(a ...interface{}) {} + +func Debugf(format string, a ...interface{}) {} \ No newline at end of file diff --git a/pkg/tls/ktls_others.go b/pkg/tls/ktls_others.go new file mode 100644 index 000000000..45ac05c07 --- /dev/null +++ b/pkg/tls/ktls_others.go @@ -0,0 +1,26 @@ +//go:build !linux +// +build !linux + +package tls + +import ( + "net" +) + +const kTLSOverhead = 0 + +func (c *Conn) enableKernelTLS(cipherSuiteID uint16, inKey, outKey, inIV, outIV []byte) error { + return nil +} + +func ktlsSendCtrlMessage(fd int, typ recordType, b []byte) (int, error) { + panic("not implement") +} + +func ktlsReadDataFromRecord(fd int, b []byte) (int, error) { + panic("not implement") +} + +func ktlsReadRecord(fd int, b []byte) (recordType, int, error) { + panic("not implement") +} From 40e9536c859a9fb2f77be88fcee9d1d3b851c865 Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Mon, 23 Jan 2023 16:55:06 +0000 Subject: [PATCH 06/34] Fix typos --- eventloop.go | 2 +- pkg/tls/handshake_client.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/eventloop.go b/eventloop.go index 9da326656..e7141397e 100644 --- a/eventloop.go +++ b/eventloop.go @@ -240,7 +240,7 @@ func (el *eventloop) closeConn(c *conn, err error) (rerr error) { return } - // clost the TLS connection by sending the alert + // close the TLS connection by sending the alert if c.tlsconn != nil { c.tlsconn.Close() } diff --git a/pkg/tls/handshake_client.go b/pkg/tls/handshake_client.go index 09ffefdf7..171eaccd9 100644 --- a/pkg/tls/handshake_client.go +++ b/pkg/tls/handshake_client.go @@ -33,7 +33,7 @@ type clientHandshakeState struct { finishedHash finishedHash masterSecret []byte session *ClientSessionState - oldsession *ClientSessionState + oldSession *ClientSessionState cacheKey string } @@ -240,7 +240,7 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { serverHello: serverHello, hello: hello, session: hs13.session, - oldsession: hs13.session, + oldSession: hs13.session, cacheKey: hs13.cacheKey, } c.hs = hs @@ -486,7 +486,7 @@ func (hs *clientHandshakeState) handshake() error { // If we had a successful handshake and hs.session is different from // the one already cached - cache a new one. - if hs.cacheKey != "" && hs.session != nil && hs.oldsession != hs.session { + if hs.cacheKey != "" && hs.session != nil && hs.oldSession != hs.session { c.config.ClientSessionCache.Put(hs.cacheKey, hs.session) } // Enable kernel TLS if possible From c7d0993df3716e91b6554bbe699e66ff289f024f Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Mon, 23 Jan 2023 19:55:23 +0000 Subject: [PATCH 07/34] bug: fix type not matching in ktlsInBufPool.Get and Put --- pkg/tls/conn.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pkg/tls/conn.go b/pkg/tls/conn.go index 20d822596..bf8bd5d36 100644 --- a/pkg/tls/conn.go +++ b/pkg/tls/conn.go @@ -614,7 +614,8 @@ func (c *Conn) readChangeCipherSpec() error { // ktlsInBufPool pools the buffers used by ktlsReadRecord. var ktlsInBufPool = sync.Pool{ New: func() any { - return new([maxPlaintext]byte) + buf := make([]byte, maxPlaintext) + return &buf }, } @@ -657,8 +658,8 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { // to the heap, causing an allocation. Instead, we keep around the // pointer to the slice header returned by Get, which is already on the // heap, and overwrite and return that. - *dataPtr = data - ktlsInBufPool.Put(data) + *dataPtr = data[:maxPlaintext] + ktlsInBufPool.Put(dataPtr) }() if typ, n, err = ktlsReadRecord(c.conn.(Socket).Fd(), data); err != nil { return err From 582f14672e4f47051949604d293fadd7e5c9cabb Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Mon, 23 Jan 2023 20:04:33 +0000 Subject: [PATCH 08/34] Add supports to TLS_TX_ZEROCOPY_RO and TLS_RX_EXPECT_NO_PAD, but not tested yet --- pkg/tls/ktls_cipher_linux.go | 42 ++++++++++++++++++++++++++++++++++++ pkg/tls/ktls_linux.go | 6 ++++-- 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/pkg/tls/ktls_cipher_linux.go b/pkg/tls/ktls_cipher_linux.go index 2517177c2..ade07136d 100644 --- a/pkg/tls/ktls_cipher_linux.go +++ b/pkg/tls/ktls_cipher_linux.go @@ -113,6 +113,10 @@ func ktlsEnableAES( ulpEnabled = true Debugln("kTLS: TLS_TX enabled") c.out.cipher = kTLSCipher{} + // Try to enable kTLS TX zerocopy sendfile. + // Only enabled if the hardware supports the protocol. + // Otherwise, get an error message which is fine. + ktlsEnableTxZerocopySendfile(sock.Fd()) } else { Debugln("kTLS: TLS_TX unsupported connection type") } @@ -132,6 +136,13 @@ func ktlsEnableAES( } Debugln("kTLS: TLS_RX enabled") c.in.cipher = kTLSCipher{} + // Only enable the TLS_RX_EXPECT_NO_PAD for TLS 1.3 + // TODO: safe to enable only if the remote end is trusted, otherwise + // it is an attack vector to doubling the TLS processing cost. + // See: https://docs.kernel.org/networking/tls.html#tls-rx-expect-no-pad + // if version == VersionTLS13 { + // ktlsEnableRxExpectNoPad(sock.Fd()) + // } } else { Debugln("kTLS: TLS_RX unsupported connection type") } @@ -158,6 +169,10 @@ func ktlsEnableCHACHA20(c *Conn, version uint16, inKey, outKey, inIV, outIV []by ulpEnabled = true Debugln("kTLS: TLS_TX enabled") c.out.cipher = kTLSCipher{} + // Try to enable kTLS TX zerocopy sendfile. + // Only enabled if the hardware supports the protocol. + // Otherwise, get an error message which is fine. + ktlsEnableTxZerocopySendfile(sock.Fd()) } else { Debugln("kTLS: TLS_TX unsupported connection type") } @@ -175,6 +190,13 @@ func ktlsEnableCHACHA20(c *Conn, version uint16, inKey, outKey, inIV, outIV []by ulpEnabled = true Debugln("kTLS: TLS_RX enabled") c.in.cipher = kTLSCipher{} + // Only enable the TLS_RX_EXPECT_NO_PAD for TLS 1.3 + // TODO: safe to enable only if the remote end is trusted, otherwise + // it is an attack vector to doubling the TLS processing cost. + // See: https://docs.kernel.org/networking/tls.html#tls-rx-expect-no-pad + // if version == VersionTLS13 { + // ktlsEnableRxExpectNoPad(sock.Fd()) + // } } else { Debugln("kTLS: TLS_RX unsupported connection type") } @@ -365,3 +387,23 @@ func ktlsEnableCHACHA20POLY1305(fd int, version uint16, opt int, skip bool, key, return err } + +func ktlsEnableTxZerocopySendfile(fd int) (err error) { + err = syscall.SetsockoptInt(int(fd), SOL_TLS, TLS_TX_ZEROCOPY_RO, 1) + if err != nil { + Debugf("kTLS: TLS_TX Zerocopy Sendfile not Enabled. Error: %s", err) + return + } + Debugln("kTLS: TLS_TX Zerocopy Sendfile Enabled") + return +} + +func ktlsEnableRxExpectNoPad(fd int) (err error) { + err = syscall.SetsockoptInt(int(fd), SOL_TLS, TLS_RX_EXPECT_NO_PAD, 1) + if err != nil { + Debugf("kTLS: TLS_RX Expect No Pad not Enabled. Error: %s", err) + return + } + Debugln("kTLS: TLS_RX Expect No Pad Enabled") + return +} diff --git a/pkg/tls/ktls_linux.go b/pkg/tls/ktls_linux.go index 61c0a758a..8ef0687c7 100644 --- a/pkg/tls/ktls_linux.go +++ b/pkg/tls/ktls_linux.go @@ -20,8 +20,10 @@ const ( TCP_ULP = 31 SOL_TLS = 282 - TLS_TX = 1 - TLS_RX = 2 + TLS_TX = 1 + TLS_RX = 2 + TLS_TX_ZEROCOPY_RO = 3 // TX zerocopy (only sendfile now) + TLS_RX_EXPECT_NO_PAD = 4 // Attempt opportunistic zero-copy, TLS 1.3 only TLS_SET_RECORD_TYPE = 1 TLS_GET_RECORD_TYPE = 2 From 29768bcd5211b22bbfcc8bbc849a45064e8482f2 Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Mon, 23 Jan 2023 22:45:59 +0000 Subject: [PATCH 09/34] bug: Fix KTLS readRecordOrCCS return EOF data should use the local declaration rather than re-declaring in the if statement, which results len(data) is 0 on line 794, resulting EOF. --- pkg/tls/conn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/tls/conn.go b/pkg/tls/conn.go index bf8bd5d36..132c2bda5 100644 --- a/pkg/tls/conn.go +++ b/pkg/tls/conn.go @@ -651,7 +651,7 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { if _, ok := c.in.cipher.(kTLSCipher); ok { dataPtr := ktlsInBufPool.Get().(*[]byte) - data := *dataPtr + data = *dataPtr defer func() { // You might be tempted to simplify this by just passing &outBuf to Put, // but that would make the local copy of the outBuf slice header escape From ee43463287bb1ae2972b56aa843bbb6666e902ed Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Mon, 23 Jan 2023 22:47:23 +0000 Subject: [PATCH 10/34] change int(fd) to fd as fd is already an int. --- pkg/tls/ktls_cipher_linux.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pkg/tls/ktls_cipher_linux.go b/pkg/tls/ktls_cipher_linux.go index ade07136d..c6b6f0c77 100644 --- a/pkg/tls/ktls_cipher_linux.go +++ b/pkg/tls/ktls_cipher_linux.go @@ -255,13 +255,13 @@ func ktlsEnableAES128GCM(fd int, version uint16, opt int, skip bool, key, iv, se } if !skip { - err = syscall.SetsockoptString(int(fd), syscall.SOL_TCP, TCP_ULP, "tls") + err = syscall.SetsockoptString(fd, syscall.SOL_TCP, TCP_ULP, "tls") if err != nil { Debugln("kTLS: setsockopt(SOL_TCP, TCP_ULP) failed:", err) return } } - err = syscall.SetsockoptString(int(fd), SOL_TLS, opt, + err = syscall.SetsockoptString(fd, SOL_TLS, opt, string((*[kTLSCryptoInfoSize_AES_GCM_128]byte)(unsafe.Pointer(&cryptoInfo))[:])) if err != nil { Debugf("kTLS: setsockopt(SOL_TLS, %d) failed: %s", opt, err) @@ -322,13 +322,13 @@ func ktlsEnableAES256GCM(fd int, version uint16, opt int, skip bool, key, iv, se } if !skip { - err = syscall.SetsockoptString(int(fd), syscall.SOL_TCP, TCP_ULP, "tls") + err = syscall.SetsockoptString(fd, syscall.SOL_TCP, TCP_ULP, "tls") if err != nil { Debugln("kTLS: setsockopt(SOL_TCP, TCP_ULP) failed:", err) return } } - err = syscall.SetsockoptString(int(fd), SOL_TLS, opt, + err = syscall.SetsockoptString(fd, SOL_TLS, opt, string((*[kTLSCryptoInfoSize_AES_GCM_256]byte)(unsafe.Pointer(&cryptoInfo))[:])) if err != nil { Debugf("kTLS: setsockopt(SOL_TLS, %d) failed: %s", opt, err) @@ -372,13 +372,13 @@ func ktlsEnableCHACHA20POLY1305(fd int, version uint16, opt int, skip bool, key, } if !skip { - err = syscall.SetsockoptString(int(fd), syscall.SOL_TCP, TCP_ULP, "tls") + err = syscall.SetsockoptString(fd, syscall.SOL_TCP, TCP_ULP, "tls") if err != nil { Debugln("kTLS: setsockopt(SOL_TCP, TCP_ULP) failed:", err) return } } - err = syscall.SetsockoptString(int(fd), SOL_TLS, opt, + err = syscall.SetsockoptString(fd, SOL_TLS, opt, string((*[kTLSCryptoInfoSize_CHACHA20_POLY1305]byte)(unsafe.Pointer(&cryptoInfo))[:])) if err != nil { Debugf("kTLS: setsockopt(SOL_TLS, %d) failed: %s", opt, err) From 8e71e26073a9a9824174d61c57c36de92c80b69b Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Wed, 25 Jan 2023 19:51:23 +0000 Subject: [PATCH 11/34] Bug: Fix kTLS 1.3 RX not working on kernel 5.15 ======================================= 1. disable kTLS 1.3 RX on kernel 5.15 2. check zero copy on kernel 5.19 3. check tls 1.3 no pad on kernel 6.0 --- pkg/tls/handshake_client_tls13.go | 1 + pkg/tls/ktls_cipher_linux.go | 26 +++++++++------- pkg/tls/ktls_linux.go | 49 ++++++++++++++++++++++++++----- 3 files changed, 58 insertions(+), 18 deletions(-) diff --git a/pkg/tls/handshake_client_tls13.go b/pkg/tls/handshake_client_tls13.go index c9b11c819..8a09b8a4b 100644 --- a/pkg/tls/handshake_client_tls13.go +++ b/pkg/tls/handshake_client_tls13.go @@ -109,6 +109,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error { if err := c.enableKernelTLS(c.cipherSuite, c.in.key, c.out.key, c.in.iv, c.out.iv); err != nil { return err } + return nil } diff --git a/pkg/tls/ktls_cipher_linux.go b/pkg/tls/ktls_cipher_linux.go index c6b6f0c77..ceda18c87 100644 --- a/pkg/tls/ktls_cipher_linux.go +++ b/pkg/tls/ktls_cipher_linux.go @@ -124,8 +124,8 @@ func ktlsEnableAES( Debugln("kTLS: TLS_TX unsupported key length") } - // Try to enable Kernel TLS RX - if !kTLSSupportRX { + // Try to enable Kernel TLS RX for TLS 1.2 or TLS 1.3 (TLS 1.3 RX is disabled on kernel < 5.19 ) + if !kTLSSupportRX || (version == VersionTLS13 && !kTLSSupportTLS13RX) { return nil } if len(inKey) == keyLen { @@ -140,9 +140,9 @@ func ktlsEnableAES( // TODO: safe to enable only if the remote end is trusted, otherwise // it is an attack vector to doubling the TLS processing cost. // See: https://docs.kernel.org/networking/tls.html#tls-rx-expect-no-pad - // if version == VersionTLS13 { - // ktlsEnableRxExpectNoPad(sock.Fd()) - // } + if version == VersionTLS13 { + ktlsEnableRxExpectNoPad(sock.Fd()) + } } else { Debugln("kTLS: TLS_RX unsupported connection type") } @@ -177,8 +177,8 @@ func ktlsEnableCHACHA20(c *Conn, version uint16, inKey, outKey, inIV, outIV []by Debugln("kTLS: TLS_TX unsupported connection type") } - // Try to enable Kernel TLS RX - if !kTLSSupportRX { + // Try to enable Kernel TLS RX for TLS 1.2 or TLS 1.3 (TLS 1.3 RX is disabled on kernel < 5.19 ) + if !kTLSSupportRX || (version == VersionTLS13 && !kTLSSupportTLS13RX) { return nil } if sock, ok := c.conn.(Socket); ok { @@ -194,9 +194,9 @@ func ktlsEnableCHACHA20(c *Conn, version uint16, inKey, outKey, inIV, outIV []by // TODO: safe to enable only if the remote end is trusted, otherwise // it is an attack vector to doubling the TLS processing cost. // See: https://docs.kernel.org/networking/tls.html#tls-rx-expect-no-pad - // if version == VersionTLS13 { - // ktlsEnableRxExpectNoPad(sock.Fd()) - // } + if version == VersionTLS13 { + ktlsEnableRxExpectNoPad(sock.Fd()) + } } else { Debugln("kTLS: TLS_RX unsupported connection type") } @@ -389,6 +389,9 @@ func ktlsEnableCHACHA20POLY1305(fd int, version uint16, opt int, skip bool, key, } func ktlsEnableTxZerocopySendfile(fd int) (err error) { + if !kTLSSupportZEROCOPY { + return nil + } err = syscall.SetsockoptInt(int(fd), SOL_TLS, TLS_TX_ZEROCOPY_RO, 1) if err != nil { Debugf("kTLS: TLS_TX Zerocopy Sendfile not Enabled. Error: %s", err) @@ -399,6 +402,9 @@ func ktlsEnableTxZerocopySendfile(fd int) (err error) { } func ktlsEnableRxExpectNoPad(fd int) (err error) { + if !kTLSSupportNOPAD { + return nil + } err = syscall.SetsockoptInt(int(fd), SOL_TLS, TLS_RX_EXPECT_NO_PAD, 1) if err != nil { Debugf("kTLS: TLS_RX Expect No Pad not Enabled. Error: %s", err) diff --git a/pkg/tls/ktls_linux.go b/pkg/tls/ktls_linux.go index 8ef0687c7..4eb18dc04 100644 --- a/pkg/tls/ktls_linux.go +++ b/pkg/tls/ktls_linux.go @@ -43,7 +43,16 @@ var ( kTLSSupportAESGCM256 bool kTLSSupportCHACHA20POLY1305 bool - kTLSSupportTLS13 bool + kTLSSupportTLS13TX bool + // TLS1.3 RX is buggy in kernel 5.15, got weird package lost + // TODO: test it on kernel 5.19 or 6+ + kTLSSupportTLS13RX bool + + // available in kernel >= 5.19 or 6+ + kTLSSupportZEROCOPY bool + + // available in kernel 6+ + kTLSSupportNOPAD bool ) func init() { @@ -84,6 +93,8 @@ func init() { return } + Debugf("Kernel Version: %s\n", release) + if (major == 4 && minor >= 13) || major > 4 { kTLSSupportTX = true kTLSSupportAESGCM128 = true @@ -95,12 +106,34 @@ func init() { if (major == 5 && minor >= 1) || major > 5 { kTLSSupportAESGCM256 = true - kTLSSupportTLS13 = true + kTLSSupportTLS13TX = true } if (major == 5 && minor >= 11) || major > 5 { kTLSSupportCHACHA20POLY1305 = true } + + if (major == 5 && minor >= 19) || major > 5 { + kTLSSupportZEROCOPY = true + kTLSSupportTLS13RX = true + } + + if major > 5 { + kTLSSupportNOPAD = true + } + + Debugln("======Supported Features======") + Debugf("kTLS TX: %v\n", kTLSSupportTX) + Debugf("kTLS RX: %v\n", kTLSSupportRX) + Debugf("kTLS TLS 1.3 TX: %v\n", kTLSSupportTLS13TX) + Debugf("kTLS TLS 1.3 RX: %v\n", kTLSSupportTLS13RX) + Debugf("kTLS TX ZeroCopy: %v\n", kTLSSupportZEROCOPY) + Debugf("kTLS RX Expected No Pad: %v\n", kTLSSupportNOPAD) + + Debugln("=========CipherSuites=========") + Debugf("kTLS AES-GCM-128: %v\n", kTLSSupportAESGCM128) + Debugf("kTLS AES-GCM-256: %v\n", kTLSSupportAESGCM256) + Debugf("kTLS CHACHA20POLY1305: %v\n", kTLSSupportCHACHA20POLY1305) } func (c *Conn) ReadFrom(r io.Reader) (n int64, err error) { @@ -307,36 +340,36 @@ func (c *Conn) enableKernelTLS(cipherSuiteID uint16, inKey, outKey, inIV, outIV if !kTLSSupportAESGCM128 { return nil } - Debugln("try to enable kernel tls AES_128_GCM") + Debugln("try to enable kernel tls AES_128_GCM for tls 1.2") return ktlsEnableAES(c, VersionTLS12, ktlsEnableAES128GCM, 16, inKey, outKey, inIV, outIV) case TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_RSA_WITH_AES_256_GCM_SHA384: if !kTLSSupportAESGCM256 { return nil } - Debugln("try to enable kernel tls AES_256_GCM") + Debugln("try to enable kernel tls AES_256_GCM for tls 1.2") return ktlsEnableAES(c, VersionTLS12, ktlsEnableAES256GCM, 32, inKey, outKey, inIV, outIV) case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256: if !kTLSSupportCHACHA20POLY1305 { return nil } - Debugln("try to enable kernel tls CHACHA20_POLY1305") + Debugln("try to enable kernel tls CHACHA20_POLY1305 for tls 1.2") return ktlsEnableCHACHA20(c, VersionTLS12, inKey, outKey, inIV, outIV) // Kernel TLS 1.3 case TLS_AES_128_GCM_SHA256: - if !kTLSSupportAESGCM128 { + if !kTLSSupportAESGCM128 || !kTLSSupportTLS13TX { return nil } Debugln("try to enable kernel tls AES_128_GCM for tls 1.3") return ktlsEnableAES(c, VersionTLS13, ktlsEnableAES128GCM, 16, inKey, outKey, inIV, outIV) case TLS_AES_256_GCM_SHA384: - if !kTLSSupportAESGCM256 { + if !kTLSSupportAESGCM256 || !kTLSSupportTLS13TX { return nil } Debugln("try to enable kernel tls AES_256_GCM tls 1.3") return ktlsEnableAES(c, VersionTLS13, ktlsEnableAES256GCM, 32, inKey, outKey, inIV, outIV) case TLS_CHACHA20_POLY1305_SHA256: - if !kTLSSupportCHACHA20POLY1305 { + if !kTLSSupportCHACHA20POLY1305 || !kTLSSupportTLS13TX { return nil } Debugln("try to enable kernel tls CHACHA20_POLY1305 for tls 1.3") From 3e95281980e841bd276d9a94c268ecad51f3a2e6 Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Wed, 25 Jan 2023 19:56:42 +0000 Subject: [PATCH 12/34] comment out dead code --- pkg/tls/conn.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pkg/tls/conn.go b/pkg/tls/conn.go index 132c2bda5..6a562f363 100644 --- a/pkg/tls/conn.go +++ b/pkg/tls/conn.go @@ -279,8 +279,9 @@ func (hc *halfConn) explicitNonceLen() int { return c.BlockSize() } return 0 - case kTLSCipher: - return 0 + // never reached, thus dead code + // case kTLSCipher: + // return 0 default: panic("unknown cipher type") } @@ -903,8 +904,9 @@ func (c *Conn) maxPayloadSizeForWrite(typ recordType) int { // The MAC is appended before padding so affects the // payload size directly. payloadBytes -= c.out.mac.Size() - case kTLSCipher: - payloadBytes -= kTLSOverhead + // never reached, thus dead code + // case kTLSCipher: + // payloadBytes -= kTLSOverhead default: panic("unknown cipher type") } From af390887bca00914e159f702c853d684d2b2da3d Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Wed, 25 Jan 2023 19:57:20 +0000 Subject: [PATCH 13/34] update go version to 1.20 --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index ac113fc99..97216c37c 100644 --- a/go.mod +++ b/go.mod @@ -19,4 +19,4 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -go 1.17 +go 1.20 From 492f83ed480fdf648addff3eb6e9316cf22ea981 Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Thu, 26 Jan 2023 15:57:51 +0000 Subject: [PATCH 14/34] TLS: optimize checking if sendBuf is empty or not --- pkg/tls/conn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/tls/conn.go b/pkg/tls/conn.go index 6a562f363..9b208f008 100644 --- a/pkg/tls/conn.go +++ b/pkg/tls/conn.go @@ -938,7 +938,7 @@ func (c *Conn) write(data []byte) (int, error) { } func (c *Conn) flush() (int, error) { - if c.sendBuf.Buffered() == 0 { + if c.sendBuf.IsEmpty() { return 0, nil } n, err := c.conn.Write(nil) From 94ad7e8d5dd51f3160ac80d1e5a75d5d90b2ded7 Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Fri, 27 Jan 2023 05:59:11 +0000 Subject: [PATCH 15/34] opt: TLS writes the data into the socket directly ====================================== 1. TLS writes the data into the socket directly rather than writing the data into the buffer. the data is buffered only if error unix.EAGAIN occurs. 2. Add "tlsEnabled bool" to control when to use tlsconn.Write(). The reason is that tlsconn.Write() encrypt the data, then calls gnetConn.Write() which could potently call either gnetConn.write() or gnetConn.writeTLS(). Therefore, we make "tlsEnabled" to false before calling tlsconn.Write(), and then restore "tlsEnabled" to true after that. 3. tlsconn.flush() calls gnetConn.Flush() to flush the buffer immediately. Therefore, we don't need to call gnetConn.Flush() in gnet TLS handshake phase as tlsconn.Handshake() calls gnetConn.Flush() implicitly. --- connection.go | 81 ++++++++++++++++++++++--------- eventloop.go | 29 ++++++++++- pkg/tls/conn.go | 29 ++++++----- pkg/tls/handshake_client.go | 2 +- pkg/tls/handshake_client_tls13.go | 2 +- pkg/tls/handshake_server.go | 4 +- pkg/tls/handshake_server_tls13.go | 2 +- 7 files changed, 108 insertions(+), 41 deletions(-) diff --git a/connection.go b/connection.go index 91cab70be..632f7fad9 100644 --- a/connection.go +++ b/connection.go @@ -50,6 +50,7 @@ type conn struct { isDatagram bool // UDP protocol opened bool // connection opened event fired tlsconn *tls.Conn // tls connection + tlsEnabled bool // whether TLS is enabled } func newTCPConn(fd int, el *eventloop, sa unix.Sockaddr, localAddr, remoteAddr net.Addr) (c *conn) { @@ -83,6 +84,7 @@ func (c *conn) releaseTCP() { c.outboundBuffer.Release() netpoll.PutPollAttachment(c.pollAttachment) c.pollAttachment = nil + c.tlsEnabled = false } func newUDPConn(fd int, el *eventloop, localAddr net.Addr, sa unix.Sockaddr, connected bool) (c *conn) { @@ -129,19 +131,22 @@ func (c *conn) open(buf []byte) error { return err } +func (c *conn) writeTLS(data []byte) (n int, err error) { + // temporarily disable the TLS connection. + c.tlsEnabled = false + // use tls to encrypt the data before sending it. + // tlsconn will implicitly call gnet.Write (but it runs the plaintext version) to sent the data. + // unsent data will be buffered + n, err = c.tlsconn.Write(data) + // re-enable the TLS connection if data is sent or is buffered + // Otherwise, the connection is closed (c.loop.closeConn() is called). + c.tlsEnabled = true + return +} + func (c *conn) write(data []byte) (n int, err error) { n = len(data) - if c.tlsconn != nil { - // use tls to encrypt the data before sending it - n, _ = c.tlsconn.Write(data) - // err = c.loop.poller.ModReadWrite(c.pollAttachment) - // n = 0 - // also working - err = c.loop.write(c) - return - } - // If there is pending data in outbound buffer, the current data ought to be appended to the outbound buffer // for maintaining the sequence of network packets. if !c.outboundBuffer.IsEmpty() { @@ -167,21 +172,37 @@ func (c *conn) write(data []byte) (n int, err error) { return } -func (c *conn) writev(bs [][]byte) (n int, err error) { +func (c *conn) writevTLS(bs [][]byte) (n int, err error) { for _, b := range bs { n += len(b) } - if c.tlsconn != nil { - for _, b := range bs { - // use tls to encrypt the data before sending it - c.tlsconn.Write(b) + // temporarily disable the TLS connection. + c.tlsEnabled = false + // use tls to encrypt the data before sending it. + // tlsconn will implicitly call gnet.Write (but it runs the plaintext version) to sent the data. + // unsent data will be buffered + sent := 0 + var sentN int + for _, b := range bs { + sentN, err = c.tlsconn.Write(b) + if sentN < 0 { + // the connection is closed (c.loop.closeConn() is called). + return sent, err } - // err = c.loop.poller.ModReadWrite(c.pollAttachment) - // n = 0 - // also working - err = c.loop.write(c) - return + sent += sentN + } + + // re-enable the TLS connection if data is sent or is buffered + // Otherwise, the connection is closed (c.loop.closeConn() is called). + c.tlsEnabled = true + + return +} + +func (c *conn) writev(bs [][]byte) (n int, err error) { + for _, b := range bs { + n += len(b) } // If there is pending data in outbound buffer, the current data ought to be appended to the outbound buffer @@ -230,7 +251,11 @@ func (c *conn) asyncWrite(itf interface{}) (err error) { } hook := itf.(*asyncWriteHook) - _, err = c.write(hook.data) + if c.tlsEnabled { + _, err = c.writeTLS(hook.data) + } else { + _, err = c.write(hook.data) + } if hook.callback != nil { _ = hook.callback(c, err) } @@ -248,7 +273,11 @@ func (c *conn) asyncWritev(itf interface{}) (err error) { } hook := itf.(*asyncWritevHook) - _, err = c.writev(hook.data) + if c.tlsEnabled { + _, err = c.writevTLS(hook.data) + } else { + _, err = c.writev(hook.data) + } if hook.callback != nil { _ = hook.callback(c, err) } @@ -373,6 +402,9 @@ func (c *conn) Write(p []byte) (int, error) { } return len(p), nil } + if c.tlsEnabled { + return c.writeTLS(p) + } return c.write(p) } @@ -380,6 +412,9 @@ func (c *conn) Writev(bs [][]byte) (int, error) { if c.isDatagram { return 0, gerrors.ErrUnsupportedOp } + if c.tlsEnabled { + return c.writevTLS(bs) + } return c.writev(bs) } @@ -494,7 +529,9 @@ func (c *conn) Close() error { } func (c *conn) UpgradeTLS(config *tls.Config) (err error) { + // TODO: create a sync.pool to manage the TLS connection c.tlsconn = tls.ServerGnet(c, &c.inboundBuffer, c.outboundBuffer, config.Clone()) + c.tlsEnabled = true //很有可能握手包在UpgradeTls之前发过来了,这里把inboundBuffer剩余数据当做握手数据处理 if c.inboundBuffer.Len() > 0 { diff --git a/eventloop.go b/eventloop.go index e7141397e..855aff813 100644 --- a/eventloop.go +++ b/eventloop.go @@ -167,7 +167,6 @@ func (el *eventloop) read(c *conn) error { return el.closeConn(c, os.NewSyscallError("TLS handshake", err)) } if !c.tlsconn.HandshakeComplete() || len(c.tlsconn.RawData()) == 0 { //握手没成功,或者握手成功,但是没有数据黏包了 - c.Flush() return nil } } @@ -241,8 +240,34 @@ func (el *eventloop) closeConn(c *conn, err error) (rerr error) { } // close the TLS connection by sending the alert + // + // tlsconn.Close() is called only if err == nil. + // Notice tlsconn.Close() eventually calls gnetConn.write(). + // gnetConn.write() is possible to call c.loop.closeConn() again if unix.Write() returns an error other than unix.EAGAIN, + // which creates a cycle. The error message could be "broken pipe" or "connection reset by peer" indicating + // the socket is no longer valid. + // + // This implies that once "err is not nil", it will create a cycle, and run the code in an infinite loop. + // Therefore, we use "if err == nil" to detect the cycle and break it. if c.tlsconn != nil { - c.tlsconn.Close() + // The default call graph results calling gnet.writeTLS(). See below + // + // tlsconn.Close() -> tlsconn.sendAlertLocked() -> tlsconn.writeRecordLocked() -> tlsconn.write() -> + // gnetConn.Write() (Here tlsEnabled is true, implying to run the TLS version) -> + // gnetConn.writeTLS() -> tlsconn.Write() -> tlsconn.writeRecordLocked() -> tlsconn.write() -> + // gnetConn.Write() (Here tlsEnabled is false, implying to run the plaintext version) -> + // gnetConn.write() + // + // Therefore, the closing message is encrypted twice which is not correct. + // To resolve the issue, we disable the TLS before closing. Then, when tlsconn.Close() + // calls gnetConn.Write(), which immediately runs the plaintext version. This means + // the encrypted data is written to the socket directly instead of being encrypted one more time. + c.tlsEnabled = false + if err == nil { + c.tlsconn.Close() + } + c.tlsconn = nil + // TODO: create a sync.pool to manage the TLS connection } // Send residual data in buffer back to the peer before actually closing the connection. diff --git a/pkg/tls/conn.go b/pkg/tls/conn.go index 9b208f008..d8f62d85f 100644 --- a/pkg/tls/conn.go +++ b/pkg/tls/conn.go @@ -26,6 +26,7 @@ import ( type Socket interface { // Fd returns the underlying file descriptor. Fd() int + Flush() error } // A Conn represents a secured connection. @@ -104,12 +105,12 @@ type Conn struct { // By using the elastic MsgBuffer the tls conn not longer holds the actual buffer when the connection is idle. // This can significantly optimize the memory usage, especially when the server connecting millions of clients // where most of them are idle. - in, out halfConn - rawInput EMsgBuffer // raw input, starting with a record header - input *elastic.RingBuffer // a buffer for decrypted records pointer to the inboundBuffer of gnet.conn - hand EMsgBuffer // handshake data waiting to be read - // buffering bool // whether records are buffered in sendBuf - sendBuf *elastic.Buffer // a buffer for records waiting to be sent also point to the outboundBuffer of gnet.conn + in, out halfConn + rawInput EMsgBuffer // raw input, starting with a record header + input *elastic.RingBuffer // a buffer for decrypted records pointer to the inboundBuffer of gnet.conn + hand EMsgBuffer // handshake data waiting to be read + buffering bool // whether records are buffered in sendBuf + sendBuf *elastic.Buffer // a buffer for records waiting to be sent also point to the outboundBuffer of gnet.conn // bytesSent counts the bytes of application data sent. // packetsSent counts packets. @@ -930,20 +931,24 @@ func (c *Conn) maxPayloadSizeForWrite(typ recordType) int { } func (c *Conn) write(data []byte) (int, error) { - //必须把所有数据往buf写 - n := len(data) - c.sendBuf.Write(data) + if c.buffering { + _, _ = c.sendBuf.Write(data) + return len(data), nil + } + + n, err := c.conn.Write(data) c.bytesSent += int64(n) - return n, nil + return n, err } func (c *Conn) flush() (int, error) { if c.sendBuf.IsEmpty() { return 0, nil } - n, err := c.conn.Write(nil) + n := c.sendBuf.Buffered() c.bytesSent += int64(n) - // c.buffering = false + err := c.conn.(Socket).Flush() + c.buffering = false return n, err } diff --git a/pkg/tls/handshake_client.go b/pkg/tls/handshake_client.go index 171eaccd9..fbba66a50 100644 --- a/pkg/tls/handshake_client.go +++ b/pkg/tls/handshake_client.go @@ -418,7 +418,7 @@ func (hs *clientHandshakeState) handshake() error { hs.finishedHash.Write(hs.hello.marshal()) hs.finishedHash.Write(hs.serverHello.marshal()) c.handshakeStatus = 3 - //c.buffering = true + c.buffering = true } if c.didResume { diff --git a/pkg/tls/handshake_client_tls13.go b/pkg/tls/handshake_client_tls13.go index 8a09b8a4b..52a75954e 100644 --- a/pkg/tls/handshake_client_tls13.go +++ b/pkg/tls/handshake_client_tls13.go @@ -75,7 +75,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error { hs.transcript.Write(hs.serverHello.marshal()) - //c.buffering = true + c.buffering = true if err := hs.processServerHello(); err != nil { return err } diff --git a/pkg/tls/handshake_server.go b/pkg/tls/handshake_server.go index 361a9f6e9..b6d8f5c9b 100644 --- a/pkg/tls/handshake_server.go +++ b/pkg/tls/handshake_server.go @@ -79,7 +79,7 @@ func (hs *serverHandshakeState) handshake() error { } // For an overview of TLS handshaking, see RFC 5246, Section 7.3. - //c.buffering = true + c.buffering = true } if hs.checkForResumption() { @@ -143,7 +143,7 @@ func (hs *serverHandshakeState) handshake() error { return err } c.clientFinishedIsFirst = true - //c.buffering = true + c.buffering = true if err := hs.sendSessionTicket(); err != nil { return err } diff --git a/pkg/tls/handshake_server_tls13.go b/pkg/tls/handshake_server_tls13.go index 860518ecb..ce879c67b 100644 --- a/pkg/tls/handshake_server_tls13.go +++ b/pkg/tls/handshake_server_tls13.go @@ -61,7 +61,7 @@ func (hs *serverHandshakeStateTLS13) handshake() error { if err := hs.pickCertificate(); err != nil { return err } - //c.buffering = true + c.buffering = true if err := hs.sendServerParameters(); err != nil { return err } From 43bf39fdf26816cfa361bf4b50465d491fe52a5e Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Fri, 27 Jan 2023 06:00:34 +0000 Subject: [PATCH 16/34] opt: don't check kTLS supports if kTLS is disabled --- pkg/tls/ktls_linux.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pkg/tls/ktls_linux.go b/pkg/tls/ktls_linux.go index 4eb18dc04..f26ac0e96 100644 --- a/pkg/tls/ktls_linux.go +++ b/pkg/tls/ktls_linux.go @@ -64,6 +64,10 @@ func init() { } kTLSSupport = true && kTLSEnabled Debugf("kTLS Enabled Status: %v\n", kTLSSupport) + // no need to check further, as KTLS is disabled + if !kTLSSupport { + return + } var uname syscall.Utsname if err := syscall.Uname(&uname); err != nil { From 76acc42e10a073cb8a95d332f7c2a15ba325a7d9 Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Fri, 27 Jan 2023 20:39:46 +0000 Subject: [PATCH 17/34] opt: remove the dead code --- pkg/tls/ktls_cipher_linux.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/tls/ktls_cipher_linux.go b/pkg/tls/ktls_cipher_linux.go index ceda18c87..1ac2c361b 100644 --- a/pkg/tls/ktls_cipher_linux.go +++ b/pkg/tls/ktls_cipher_linux.go @@ -187,7 +187,6 @@ func ktlsEnableCHACHA20(c *Conn, version uint16, inKey, outKey, inIV, outIV []by Debugln("kTLS: TLS_RX error enabling:", err) return err } - ulpEnabled = true Debugln("kTLS: TLS_RX enabled") c.in.cipher = kTLSCipher{} // Only enable the TLS_RX_EXPECT_NO_PAD for TLS 1.3 From c377ecedbcb67437507f0f7d9f93566346f6798c Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Mon, 30 Jan 2023 17:43:44 +0000 Subject: [PATCH 18/34] opt: zero-copy buffer in gnet TLS implementation ======================================== Redesign the buffer in gnet TLS implementation to achieve zero-copy. Background: - tlsconn.rawInput: raw input from TCP to hold the TLS record - tlsconn.input: buffer to hold decrypted TLS record - tlsconn.hand: buffer to hold handshake data - tlsconn.sendBuf: buffer to hold sending data Problems: - Memory copy in TLS read: In the previous implementation, tlsconn.input refers to the gnetConn.inboundBuffer. To decrypted, we copy el.buffer to tlsconn.rawInput. The TLS connection, write the decrypted data to tlsconn.input, which is gnetConn.inboundBuffer. When el.eventHandler.OnTraffic() is triggered, gnetConn.Next() and gnet.Conn.Peek() can trigger more data copy as it can write to c.loop.cache() - Memory copy in TLS write: In the previous implementation, all encrypted data are first written to tlsconn.sendBuf, which refers to gnetConn.outboundBuffer. Then, tlsconn.Write() calls gnetConn.Write() which flushes the buffer to the socket New implementation: We designed LazyBuffer (lb) which has a buf []byte and its reference ref *[]byte. In the lazy mode, lb.ref is always nil, lb.buf is readonly. When calling lb.Write(), lb request a buffer from the sync.Pool, and copies lb.buf to the new buffer. Both lb.buf and lb.ref point to the new buffer. - New TLS read: With LazyBuffer, we let tlsconn.rawInput refer to el.buffer. Decrypted data stores in tlsconn.rawInput as well. tlsconn.Data() returns the reference of all decrypted data, and will be assigned to gnetConn.buffer. - New TLS write: tlsconn.Write() first encrypts the data, then calls gnetConn.WriteTCP() which directly writes the data to the socket. - New TLS handshake: we restore the tlsconn.Buffering flag which is only used in the handshake. Incoming handshake data is stored in tlsconn.hand and will be discarded immediately after being used. Outgoing handshake data is buffered in tlsconn.sendBuf, and will be flushed after calling tlsconn.flush() which calls gnetConn.WriteTCP() which directly writes the data to the socket. --- connection.go | 34 +++++---- eventloop.go | 66 +++++++++++------ pkg/tls/bufLazy.go | 141 ++++++++++++++++++++++++++++++++++++ pkg/tls/bufLazy_test.go | 108 ++++++++++++++++++++++++++++ pkg/tls/conn.go | 154 ++++++++++++++++++++++++---------------- pkg/tls/tls.go | 43 +---------- 6 files changed, 405 insertions(+), 141 deletions(-) create mode 100644 pkg/tls/bufLazy.go create mode 100644 pkg/tls/bufLazy_test.go diff --git a/connection.go b/connection.go index 632f7fad9..abbd6acee 100644 --- a/connection.go +++ b/connection.go @@ -50,7 +50,7 @@ type conn struct { isDatagram bool // UDP protocol opened bool // connection opened event fired tlsconn *tls.Conn // tls connection - tlsEnabled bool // whether TLS is enabled + // tlsEnabled bool // whether TLS is enabled } func newTCPConn(fd int, el *eventloop, sa unix.Sockaddr, localAddr, remoteAddr net.Addr) (c *conn) { @@ -84,7 +84,7 @@ func (c *conn) releaseTCP() { c.outboundBuffer.Release() netpoll.PutPollAttachment(c.pollAttachment) c.pollAttachment = nil - c.tlsEnabled = false + // c.tlsEnabled = false } func newUDPConn(fd int, el *eventloop, localAddr net.Addr, sa unix.Sockaddr, connected bool) (c *conn) { @@ -133,14 +133,14 @@ func (c *conn) open(buf []byte) error { func (c *conn) writeTLS(data []byte) (n int, err error) { // temporarily disable the TLS connection. - c.tlsEnabled = false + // c.tlsEnabled = false // use tls to encrypt the data before sending it. // tlsconn will implicitly call gnet.Write (but it runs the plaintext version) to sent the data. // unsent data will be buffered n, err = c.tlsconn.Write(data) // re-enable the TLS connection if data is sent or is buffered // Otherwise, the connection is closed (c.loop.closeConn() is called). - c.tlsEnabled = true + // c.tlsEnabled = true return } @@ -178,7 +178,7 @@ func (c *conn) writevTLS(bs [][]byte) (n int, err error) { } // temporarily disable the TLS connection. - c.tlsEnabled = false + // c.tlsEnabled = false // use tls to encrypt the data before sending it. // tlsconn will implicitly call gnet.Write (but it runs the plaintext version) to sent the data. // unsent data will be buffered @@ -195,7 +195,7 @@ func (c *conn) writevTLS(bs [][]byte) (n int, err error) { // re-enable the TLS connection if data is sent or is buffered // Otherwise, the connection is closed (c.loop.closeConn() is called). - c.tlsEnabled = true + // c.tlsEnabled = true return } @@ -251,7 +251,7 @@ func (c *conn) asyncWrite(itf interface{}) (err error) { } hook := itf.(*asyncWriteHook) - if c.tlsEnabled { + if c.tlsconn != nil { _, err = c.writeTLS(hook.data) } else { _, err = c.write(hook.data) @@ -273,7 +273,7 @@ func (c *conn) asyncWritev(itf interface{}) (err error) { } hook := itf.(*asyncWritevHook) - if c.tlsEnabled { + if c.tlsconn != nil { _, err = c.writevTLS(hook.data) } else { _, err = c.writev(hook.data) @@ -402,17 +402,23 @@ func (c *conn) Write(p []byte) (int, error) { } return len(p), nil } - if c.tlsEnabled { + if c.tlsconn != nil { return c.writeTLS(p) } return c.write(p) } +// Expose the plaintext write API which should only be used +// by tlsconn.Write(). +func (c *conn) WriteTCP(p []byte) (int, error) { + return c.write(p) +} + func (c *conn) Writev(bs [][]byte) (int, error) { if c.isDatagram { return 0, gerrors.ErrUnsupportedOp } - if c.tlsEnabled { + if c.tlsconn != nil { return c.writevTLS(bs) } return c.writev(bs) @@ -530,14 +536,14 @@ func (c *conn) Close() error { func (c *conn) UpgradeTLS(config *tls.Config) (err error) { // TODO: create a sync.pool to manage the TLS connection - c.tlsconn = tls.ServerGnet(c, &c.inboundBuffer, c.outboundBuffer, config.Clone()) - c.tlsEnabled = true + c.tlsconn = tls.Server(c, config.Clone()) + // c.tlsEnabled = true //很有可能握手包在UpgradeTls之前发过来了,这里把inboundBuffer剩余数据当做握手数据处理 if c.inboundBuffer.Len() > 0 { head, tail := c.inboundBuffer.Peek(-1) - c.tlsconn.RawWrite(head) - c.tlsconn.RawWrite(tail) + c.tlsconn.RawInputSet(head) + c.tlsconn.RawInputSet(tail) c.inboundBuffer.Reset() if err := c.tlsconn.Handshake(); err != nil { return err diff --git a/eventloop.go b/eventloop.go index 855aff813..84876ec11 100644 --- a/eventloop.go +++ b/eventloop.go @@ -115,23 +115,35 @@ func (el *eventloop) open(c *conn) error { } func (el *eventloop) readTLS(c *conn) error { - if err := c.tlsconn.ReadFrame(); err != nil { - return el.closeConn(c, os.NewSyscallError("TLS read", err)) - } + // Since the el.Buffer may contain multiple TLS record, + // we process one TLS record in each iteration until no more + // TLS records are available + for { + if err := c.tlsconn.ReadFrame(); err != nil { + return el.closeConn(c, os.NewSyscallError("TLS read", err)) + } - if c.inboundBuffer.IsEmpty() { - return nil - } + // load all decrypted data and make it ready for gnet to use + c.buffer = c.tlsconn.Data() - action := el.eventHandler.OnTraffic(c) - switch action { - case None: - case Close: - return el.closeConn(c, nil) - case Shutdown: - return gerrors.ErrEngineShutdown + action := el.eventHandler.OnTraffic(c) + switch action { + case None: + case Close: + // tls data will be cleaned up in el.closeConn() + return el.closeConn(c, nil) + case Shutdown: + c.tlsconn.DataDone() + return gerrors.ErrEngineShutdown + } + _, _ = c.inboundBuffer.Write(c.buffer) + + // all available TLS records are processed + if !c.tlsconn.IsRecordCompleted(c.tlsconn.RawInputData()) { + c.tlsconn.DataDone() + return nil + } } - return nil } func (el *eventloop) read(c *conn) error { @@ -139,8 +151,10 @@ func (el *eventloop) read(c *conn) error { // This only happens after TLS handshake is completed. // Therefore, no need to call c.tlsconn.HandshakeComplete() // In addition, all data are copied directly from kernel to the buffer, - // meaning no need to call unix.read(c.fd, el.buffer) + // el.buffer meaning no need to call unix.read(c.fd, el.buffer) if c.tlsconn != nil && c.tlsconn.IsKTLSRXEnabled() { + // attach the gnet eventloop.buffer to tlsconn.rawInput + c.tlsconn.RawInputSet(el.buffer) return el.readTLS(c) } @@ -156,17 +170,21 @@ func (el *eventloop) read(c *conn) error { } if c.tlsconn != nil { - c.tlsconn.RawWrite(el.buffer[:n]) + c.tlsconn.RawInputSet(el.buffer[:n]) if !c.tlsconn.HandshakeComplete() { //先判断是否足够一条消息 - data := c.tlsconn.RawData() - if len(data) < 5 || len(data) < 5+int(data[3])<<8|int(data[4]) { + data := c.tlsconn.RawInputData() + if !c.tlsconn.IsRecordCompleted(data) { + c.tlsconn.DataDone() return nil } if err = c.tlsconn.Handshake(); err != nil { + // closeConn will cleanup the TLS data at the end, + // so need to call tlsconn.DataDone() return el.closeConn(c, os.NewSyscallError("TLS handshake", err)) } - if !c.tlsconn.HandshakeComplete() || len(c.tlsconn.RawData()) == 0 { //握手没成功,或者握手成功,但是没有数据黏包了 + if !c.tlsconn.HandshakeComplete() || len(c.tlsconn.RawInputData()) == 0 { //握手没成功,或者握手成功,但是没有数据黏包了 + c.tlsconn.DataDone() return nil } } @@ -262,10 +280,12 @@ func (el *eventloop) closeConn(c *conn, err error) (rerr error) { // To resolve the issue, we disable the TLS before closing. Then, when tlsconn.Close() // calls gnetConn.Write(), which immediately runs the plaintext version. This means // the encrypted data is written to the socket directly instead of being encrypted one more time. - c.tlsEnabled = false - if err == nil { - c.tlsconn.Close() - } + // c.tlsEnabled = false + // if err == nil { + c.tlsconn.Close() + // } + // Make sure all memory requested from the pool is returned. + c.tlsconn.DataCleanUpAfterClose() c.tlsconn = nil // TODO: create a sync.pool to manage the TLS connection } diff --git a/pkg/tls/bufLazy.go b/pkg/tls/bufLazy.go new file mode 100644 index 000000000..0ae104e2c --- /dev/null +++ b/pkg/tls/bufLazy.go @@ -0,0 +1,141 @@ +package tls + +import ( + "sync" +) + +var bytePool = sync.Pool{ + New: func() any { + buf := make([]byte, defaultSize) + return &buf + }, +} + +type LazyBuffer struct { + buf []byte + ref *[]byte +} + +func (lb *LazyBuffer) Bytes() []byte { + return lb.buf +} + +func (lb *LazyBuffer) Len() int { + return len(lb.buf) +} + +func (lb *LazyBuffer) Next(n int) []byte { + m := lb.Len() + if n > m { + n = m + } + data := lb.buf[:n] + lb.buf = lb.buf[n:] + return data +} + +func (lb *LazyBuffer) tryGrowByReslice(n int) (int, bool) { + if l := len(lb.buf); n <= cap(lb.buf)-l { + lb.buf = lb.buf[:l+n] + return l, true + } + return 0, false +} + +func (lb *LazyBuffer) growSlice(n int) { + // TODO(http://golang.org/issue/51462): We should rely on the append-make + // pattern so that the compiler can call runtime.growslice. For example: + // return append(b, make([]byte, n)...) + // This avoids unnecessary zero-ing of the first len(b) bytes of the + // allocated slice, but this pattern causes b to escape onto the heap. + // + // Instead use the append-make pattern with a nil slice to ensure that + // we allocate buffers rounded up to the closest size class. + c := len(lb.buf) + n // ensure enough space for n elements + if c < 2*cap(*lb.ref) { + // The growth rate has historically always been 2x. In the future, + // we could rely purely on append to determine the growth rate. + c = 2 * cap(*lb.ref) + } + b2 := append([]byte(nil), make([]byte, c)...) + copy(b2, lb.buf) + lb.buf = b2 + lb.ref = &b2 +} + +func (lb *LazyBuffer) grow(n int) int { + m := lb.Len() + // Try to grow by means of a reslice. + if i, ok := lb.tryGrowByReslice(n); ok { + return i + } + + if lb.ref == nil { + lb.ref = bytePool.Get().(*[]byte) + } + + c := cap(*lb.ref) + if n <= c/2-m { + // We can slide things down instead of allocating a new + // slice. We only need m+n <= c to slide, but + // we instead let capacity get twice as large so we + // don't spend all our time copying. + copy(*lb.ref, lb.buf) + } else { + // Add b.off to account for b.buf[:b.off] being sliced off the front. + lb.growSlice(n) + } + + lb.buf = (*lb.ref)[:m+n] + return m +} + +func (lb *LazyBuffer) Grow(n int) { + m := lb.grow(n) + lb.buf = lb.buf[:m] +} + +func (lb *LazyBuffer) Extend(n int) { + lb.grow(n) +} + +func (lb *LazyBuffer) Truncate(n int) { + if 0 <= n && n <= len(lb.buf) { + lb.buf = lb.buf[:n] + } +} + +func (lb *LazyBuffer) Write(p []byte) (n int, err error) { + m := len(lb.buf) + if lb.ref == nil { + oldData := lb.buf + lb.ref = bytePool.Get().(*[]byte) + lb.buf = (*lb.ref)[:0] + lb.grow(m + len(p)) + copy(lb.buf, oldData) + } else { + lb.grow(len(p)) + } + + return copy(lb.buf[m:], p), nil +} + +func (lb *LazyBuffer) Set(p []byte) { + if lb.ref == nil && lb.Len() == 0 { + lb.buf = p + } else { + lb.Write(p) + } +} + +func (lb *LazyBuffer) Done() { + lb.buf = nil + if lb.ref != nil { + bytePool.Put(lb.ref) + lb.ref = nil + } +} + +func (lb *LazyBuffer) IsLazy() bool { + return lb.ref == nil +} diff --git a/pkg/tls/bufLazy_test.go b/pkg/tls/bufLazy_test.go new file mode 100644 index 000000000..61064596c --- /dev/null +++ b/pkg/tls/bufLazy_test.go @@ -0,0 +1,108 @@ +package tls + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func eval(t *testing.T, buf *LazyBuffer, data []byte, n int, isLazy bool) { + assert.EqualValues(t, buf.Bytes(), data) + assert.EqualValues(t, buf.Len(), n) + assert.EqualValues(t, buf.IsLazy(), isLazy) +} + +func TestBufLazyEmpty(t *testing.T) { + var buf LazyBuffer + eval(t, &buf, []byte(nil), 0, true) + buf.Write(nil) + eval(t, &buf, []byte{}, 0, false) +} + +func TestBufLazyLazyMode(t *testing.T) { + var buf LazyBuffer + var data []byte = []byte("Hello World!") + buf.Set(data) + eval(t, &buf, data, len(data), true) + + // Next 1 byte + assert.EqualValues(t, buf.Next(1), data[:1]) + eval(t, &buf, data[1:], len(data)-1, true) + + // Next remaining byte + assert.EqualValues(t, buf.Next(len(data)-1), data[1:]) + eval(t, &buf, make([]byte, 0), 0, true) + + // Next 1 more byte + assert.EqualValues(t, buf.Next(1), data[:0]) + eval(t, &buf, make([]byte, 0), 0, true) + + // Reset the data, must be in lazy mode as the previous data is drained + buf.Set(data) + eval(t, &buf, data, len(data), true) + + // Next all byte + 1 + assert.EqualValues(t, buf.Next(len(data)+1), data) + eval(t, &buf, make([]byte, 0), 0, true) + + // Done + buf.Done() + eval(t, &buf, []byte(nil), 0, true) +} + +func TestBufLazyLazyToWriteMode(t *testing.T) { + var buf LazyBuffer + var data []byte = []byte("Hello World!") + buf.Set(data) + eval(t, &buf, data, len(data), true) + + // switch to write + buf.Set(data) + doubleData := append(data, data...) + eval(t, &buf, doubleData, len(doubleData), false) + + // append new data + doubleData = append(doubleData, data...) + buf.Set(data) + eval(t, &buf, doubleData, len(doubleData), false) + + buf.Done() + eval(t, &buf, []byte(nil), 0, true) +} + +func TestBufLazyWriteMode(t *testing.T) { + var buf LazyBuffer + var data []byte = []byte(strings.Repeat("A", defaultSize)) + + buf.Grow(defaultSize) + // fill up the default buffer + buf.Write(data) + eval(t, &buf, data, len(data), false) + + // grow the buffer + buf.Write(data[:1]) + doubleData := append(data, data...) + eval(t, &buf, doubleData[:len(data)+1], len(data)+1, false) + + // fill up the remaining buffer + buf.Write(data[1:]) + eval(t, &buf, doubleData, len(doubleData), false) + + // consume half of the data + assert.EqualValues(t, buf.Next(len(data)), data) + eval(t, &buf, data, len(data), false) + + // consume 1 byte + assert.EqualValues(t, buf.Next(1), data[:1]) + eval(t, &buf, data[1:], len(data)-1, false) + + // grow 1 byte, the data is copied to the beginning + // slide things down + buf.Grow(1) + eval(t, &buf, data[1:], len(data)-1, false) + + // Done + buf.Done() + eval(t, &buf, []byte(nil), 0, true) +} diff --git a/pkg/tls/conn.go b/pkg/tls/conn.go index d8f62d85f..d921bf998 100644 --- a/pkg/tls/conn.go +++ b/pkg/tls/conn.go @@ -7,6 +7,7 @@ package tls import ( + "bytes" "context" "crypto/cipher" "crypto/subtle" @@ -18,15 +19,13 @@ import ( "net" "sync" "time" - - "github.com/panjf2000/gnet/v2/pkg/buffer/elastic" ) // Socket is a set of functions which manipulate the underlying file descriptor of a connection. type Socket interface { // Fd returns the underlying file descriptor. Fd() int - Flush() error + WriteTCP([]byte) (int, error) } // A Conn represents a secured connection. @@ -106,11 +105,12 @@ type Conn struct { // This can significantly optimize the memory usage, especially when the server connecting millions of clients // where most of them are idle. in, out halfConn - rawInput EMsgBuffer // raw input, starting with a record header - input *elastic.RingBuffer // a buffer for decrypted records pointer to the inboundBuffer of gnet.conn - hand EMsgBuffer // handshake data waiting to be read - buffering bool // whether records are buffered in sendBuf - sendBuf *elastic.Buffer // a buffer for records waiting to be sent also point to the outboundBuffer of gnet.conn + rawInput LazyBuffer // raw input, starting with a record header + input bytes.Reader // a buffer for decrypted records pointer to the inboundBuffer of gnet.conn + data []byte // buffer to hold all decrypted data + hand LazyBuffer // handshake data waiting to be read + buffering bool // whether records are buffered in sendBuf + sendBuf LazyBuffer // a buffer for records waiting to be sent also point to the outboundBuffer of gnet.conn // bytesSent counts the bytes of application data sent. // packetsSent counts packets. @@ -599,18 +599,12 @@ func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeader } func (c *Conn) readRecord() error { - if c.rawInput.Len() > recordHeaderLen { - return c.readRecordOrCCS(false) - } - return io.EOF + return c.readRecordOrCCS(false) } func (c *Conn) readChangeCipherSpec() error { - c.input.Reset() - if c.rawInput.Len() > recordHeaderLen { - return c.readRecordOrCCS(true) - } - return io.EOF + return c.readRecordOrCCS(true) + } // ktlsInBufPool pools the buffers used by ktlsReadRecord. @@ -642,37 +636,34 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { handshakeComplete := c.HandshakeComplete() var ( - typ recordType - data []byte - // record []byte - hdr []byte - n int - vers uint16 - err error + typ recordType + data []byte + record []byte + hdr []byte + n int + vers uint16 + err error ) if _, ok := c.in.cipher.(kTLSCipher); ok { - dataPtr := ktlsInBufPool.Get().(*[]byte) - data = *dataPtr - defer func() { - // You might be tempted to simplify this by just passing &outBuf to Put, - // but that would make the local copy of the outBuf slice header escape - // to the heap, causing an allocation. Instead, we keep around the - // pointer to the slice header returned by Get, which is already on the - // heap, and overwrite and return that. - *dataPtr = data[:maxPlaintext] - ktlsInBufPool.Put(dataPtr) - }() + if c.rawInput.Len() < maxPlaintext { + c.rawInput.Extend(maxPlaintext - c.rawInput.Len()) + } + data = c.rawInput.Bytes()[:maxPlaintext] if typ, n, err = ktlsReadRecord(c.conn.(Socket).Fd(), data); err != nil { return err } - data = data[:n] + data = c.rawInput.Next(n) // TODO: process the data here instead of goto processMessage // && try to use ktlsReadRecord to write data directly into input // rather than copy it later. goto processMessage } + if c.rawInput.Len() <= recordHeaderLen { + return io.EOF + } + hdr = c.rawInput.Bytes() typ = recordType(hdr[0]) @@ -713,9 +704,8 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { } // Process message. - c.rawInput.DiscardWithoutDone(recordHeaderLen + n) - defer c.rawInput.DoneIfEmpty() - data, typ, err = c.in.decrypt(hdr[:recordHeaderLen+n]) + record = c.rawInput.Next(recordHeaderLen + n) + data, typ, err = c.in.decrypt(record) if err != nil { return c.in.setErrorLocked(c.sendAlert(err.(alert))) } @@ -799,13 +789,14 @@ processMessage: // Note that data is owned by c.rawInput, following the Next call above, // to avoid copying the plaintext. This is safe because c.rawInput is // not read from or written to until c.input is drained. - c.input.Write(data) + c.input.Reset(data) + c.data = data case recordTypeHandshake: if len(data) == 0 || expectChangeCipherSpec { return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) } - c.hand.Write(data) + c.hand.Set(data) } return nil @@ -819,11 +810,7 @@ func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error { c.sendAlert(alertUnexpectedMessage) return c.in.setErrorLocked(errors.New("tls: too many ignored records")) } - c.input.Reset() - if c.rawInput.Len() > recordHeaderLen { - return c.readRecordOrCCS(expectChangeCipherSpec) - } - return io.EOF + return c.readRecordOrCCS(expectChangeCipherSpec) } // sendAlert sends a TLS alert message. @@ -936,18 +923,18 @@ func (c *Conn) write(data []byte) (int, error) { return len(data), nil } - n, err := c.conn.Write(data) + n, err := c.conn.(Socket).WriteTCP(data) c.bytesSent += int64(n) return n, err } func (c *Conn) flush() (int, error) { - if c.sendBuf.IsEmpty() { + if c.sendBuf.Len() == 0 { return 0, nil } - n := c.sendBuf.Buffered() + n, err := c.conn.(Socket).WriteTCP(c.sendBuf.Bytes()) c.bytesSent += int64(n) - err := c.conn.(Socket).Flush() + c.sendBuf.Done() c.buffering = false return n, err } @@ -1046,7 +1033,7 @@ func (c *Conn) readHandshake() (interface{}, error) { } } - data := c.hand.Peek(4) + data := c.hand.Bytes() n := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) if n > maxHandshake { c.sendAlertLocked(alertInternalError) @@ -1057,8 +1044,9 @@ func (c *Conn) readHandshake() (interface{}, error) { return nil, err } } - data = c.hand.Peek(4 + n) - defer c.hand.Discard(4 + n) + data = c.hand.Next(4 + n) + // Handshake messages are all processed, return the buffer to the pool + defer c.hand.Done() var m handshakeMessage switch data[0] { case typeHelloRequest: @@ -1173,26 +1161,66 @@ func (c *Conn) Write(b []byte) (int, error) { return n + m, c.out.setErrorLocked(err) } -// load the data into the TLS rawInput -func (c *Conn) RawWrite(data []byte) (int, error) { +// check whether the data is a complete TLS record +func (c *Conn) IsRecordCompleted(data []byte) bool { + if len(data) < recordHeaderLen { + return false + } + if len(data) < recordHeaderLen+int(data[3])<<8|int(data[4]) { + return false + } + return true +} - c.rawInput.Write(data) +// load the data into the TLS rawInput +// If rawInput is lazy and empty, the data is loaded immediately +// as a reference (zero-copy) +func (c *Conn) RawInputSet(data []byte) (int, error) { + c.rawInput.Set(data) return len(data), nil } -// Decrypt one tls record and save it in the 解析一条tls数据 +// Decrypt one tls record and save it in c.input which is +// owned by c.rawInput func (c *Conn) ReadFrame() error { - _, ok := c.in.cipher.(kTLSCipher) - if c.rawInput.Len() > recordHeaderLen || ok { - return c.readRecordOrCCS(false) - } - return io.EOF + return c.readRecordOrCCS(false) } -func (c *Conn) RawData() []byte { +// return all rawInput data +func (c *Conn) RawInputData() []byte { return c.rawInput.Bytes() } +// Clean up all decrypted data +// rawInput is cleaned up if all rawInput are processed. +// otherwise raw data is cached in order to make sure the +// is not owned by anyone else +func (c *Conn) DataDone() { + if c.rawInput.Len() == 0 { + // raw input is drain, thus clean it up + c.rawInput.Done() + } else { + // raw data has a few bytes left but not sufficient, so we cache it + // to make sure the data is not owned by anyone else + c.rawInput.Write(nil) + } + c.input.Reset(nil) + c.data = nil +} + +// call this function after close. so allocated memory for rawInput and hand +// are returned to the pool +func (c *Conn) DataCleanUpAfterClose() { + c.data = nil + c.input.Reset(nil) + c.rawInput.Done() +} + +// Return all decrypted data +func (c *Conn) Data() []byte { + return c.data +} + // handleRenegotiation processes a HelloRequest handshake message. func (c *Conn) handleRenegotiation() error { if c.vers == VersionTLS13 { diff --git a/pkg/tls/tls.go b/pkg/tls/tls.go index 0dd65b0c2..c20aa3ab1 100644 --- a/pkg/tls/tls.go +++ b/pkg/tls/tls.go @@ -24,52 +24,16 @@ import ( "net" "os" "strings" - - "github.com/panjf2000/gnet/v2/pkg/buffer/elastic" ) -// Server returns a new TLS server side connection -// using conn as the underlying transport. -// The configuration config must be non-nil and must include -// at least one certificate or else set GetCertificate. -func ServerGnet(conn net.Conn, in *elastic.RingBuffer, out *elastic.Buffer, config *Config) *Conn { - c := &Conn{ - conn: conn, - config: config, - input: in, - sendBuf: out, - } - c.handshakeFn = c.serverHandshake - return c -} - -// Client returns a new TLS client side connection -// using conn as the underlying transport. -// The config cannot be nil: users must set either ServerName or -// InsecureSkipVerify in the config. -func ClientGnet(conn net.Conn, in *elastic.RingBuffer, out *elastic.Buffer, config *Config) *Conn { - c := &Conn{ - conn: conn, - config: config, - input: in, - sendBuf: out, - isClient: true, - } - c.handshakeFn = c.clientHandshake - return c -} - // Server returns a new TLS server side connection // using conn as the underlying transport. // The configuration config must be non-nil and must include // at least one certificate or else set GetCertificate. func Server(conn net.Conn, config *Config) *Conn { - sendBuf, _ := elastic.New(65536) c := &Conn{ - conn: conn, - config: config, - input: new(elastic.RingBuffer), - sendBuf: sendBuf, + conn: conn, + config: config, } c.handshakeFn = c.serverHandshake return c @@ -80,12 +44,9 @@ func Server(conn net.Conn, config *Config) *Conn { // The config cannot be nil: users must set either ServerName or // InsecureSkipVerify in the config. func Client(conn net.Conn, config *Config) *Conn { - sendBuf, _ := elastic.New(65536) c := &Conn{ conn: conn, config: config, - input: new(elastic.RingBuffer), - sendBuf: sendBuf, isClient: true, } c.handshakeFn = c.clientHandshake From d24fd002261547be9ee1b91d87337494cef5410e Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Mon, 30 Jan 2023 17:46:16 +0000 Subject: [PATCH 19/34] bug: Fix unix.EAGAIN error returned by TLS read. --- eventloop.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/eventloop.go b/eventloop.go index 84876ec11..51c2c242b 100644 --- a/eventloop.go +++ b/eventloop.go @@ -120,6 +120,12 @@ func (el *eventloop) readTLS(c *conn) error { // TLS records are available for { if err := c.tlsconn.ReadFrame(); err != nil { + // If err is io.EOF, it can either the data is drained, + // receives a close notify from the client. + if err == unix.EAGAIN { + c.tlsconn.DataDone() + return nil + } return el.closeConn(c, os.NewSyscallError("TLS read", err)) } From 3f215222ce5a59ba2bdcef930cc98a7f58c2d655 Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Mon, 30 Jan 2023 20:28:04 +0000 Subject: [PATCH 20/34] opt: replace syscall (deprecated golang library) with golang.org/x/sys/unix --- pkg/tls/ktls_linux.go | 94 +++++++++++++++++++++---------------------- 1 file changed, 46 insertions(+), 48 deletions(-) diff --git a/pkg/tls/ktls_linux.go b/pkg/tls/ktls_linux.go index f26ac0e96..85d051860 100644 --- a/pkg/tls/ktls_linux.go +++ b/pkg/tls/ktls_linux.go @@ -10,7 +10,6 @@ import ( "os" "strconv" "strings" - "syscall" "unsafe" "golang.org/x/sys/unix" @@ -63,14 +62,14 @@ func init() { return } kTLSSupport = true && kTLSEnabled - Debugf("kTLS Enabled Status: %v\n", kTLSSupport) + Debugf("kTLS Enabled Status: %v", kTLSSupport) // no need to check further, as KTLS is disabled if !kTLSSupport { return } - var uname syscall.Utsname - if err := syscall.Uname(&uname); err != nil { + var uname unix.Utsname + if err := unix.Uname(&uname); err != nil { Debugf("kTLS: call uname failed %v", err) return } @@ -97,7 +96,7 @@ func init() { return } - Debugf("Kernel Version: %s\n", release) + Debugf("Kernel Version: %s", release) if (major == 4 && minor >= 13) || major > 4 { kTLSSupportTX = true @@ -127,17 +126,17 @@ func init() { } Debugln("======Supported Features======") - Debugf("kTLS TX: %v\n", kTLSSupportTX) - Debugf("kTLS RX: %v\n", kTLSSupportRX) - Debugf("kTLS TLS 1.3 TX: %v\n", kTLSSupportTLS13TX) - Debugf("kTLS TLS 1.3 RX: %v\n", kTLSSupportTLS13RX) - Debugf("kTLS TX ZeroCopy: %v\n", kTLSSupportZEROCOPY) - Debugf("kTLS RX Expected No Pad: %v\n", kTLSSupportNOPAD) + Debugf("kTLS TX: %v", kTLSSupportTX) + Debugf("kTLS RX: %v", kTLSSupportRX) + Debugf("kTLS TLS 1.3 TX: %v", kTLSSupportTLS13TX) + Debugf("kTLS TLS 1.3 RX: %v", kTLSSupportTLS13RX) + Debugf("kTLS TX ZeroCopy: %v", kTLSSupportZEROCOPY) + Debugf("kTLS RX Expected No Pad: %v", kTLSSupportNOPAD) Debugln("=========CipherSuites=========") - Debugf("kTLS AES-GCM-128: %v\n", kTLSSupportAESGCM128) - Debugf("kTLS AES-GCM-256: %v\n", kTLSSupportAESGCM256) - Debugf("kTLS CHACHA20POLY1305: %v\n", kTLSSupportCHACHA20POLY1305) + Debugf("kTLS AES-GCM-128: %v", kTLSSupportAESGCM128) + Debugf("kTLS AES-GCM-256: %v", kTLSSupportAESGCM256) + Debugf("kTLS CHACHA20POLY1305: %v", kTLSSupportCHACHA20POLY1305) } func (c *Conn) ReadFrom(r io.Reader) (n int64, err error) { @@ -171,12 +170,12 @@ func (c *Conn) writeToFile(f *os.File, remain int64) (written int64, err error, // mmap must align on a page boundary // mmap from 0, use data from offset - bytes, err := syscall.Mmap(int(f.Fd()), 0, int(offset+remain), - syscall.PROT_WRITE, syscall.MAP_SHARED) + bytes, err := unix.Mmap(int(f.Fd()), 0, int(offset+remain), + unix.PROT_WRITE, unix.MAP_SHARED) if err != nil { return 0, nil, false } - defer syscall.Munmap(bytes) + defer unix.Munmap(bytes) bytes = bytes[offset : offset+remain] var ( @@ -245,7 +244,7 @@ func (c *Conn) spliceToFile(f *os.File, remain int64) (written int64, err error, n, err = unix.Splice(int(rfd), nil, pwfd, nil, int(n), unix.SPLICE_F_MORE) remain -= n written += n - if err == syscall.EAGAIN { + if err == unix.EAGAIN { // return false to wait data from connection err = nil return false @@ -286,8 +285,8 @@ func (c *Conn) spliceToFile(f *os.File, remain int64) (written int64, err error, // destroyTempPipe destroys a temporary pipe. func destroyTempPipe(prfd, pwfd int) error { - err := syscall.Close(prfd) - err1 := syscall.Close(pwfd) + err := unix.Close(prfd) + err1 := unix.Close(pwfd) if err == nil { return err1 } @@ -382,26 +381,25 @@ func (c *Conn) enableKernelTLS(cipherSuiteID uint16, inKey, outKey, inIV, outIV return nil } -func ktlsReadRecord(fd int, b []byte) (recordType, int, error) { +func ktlsReadRecord(fd int, b []byte) (typ recordType, n int, err error) { // cmsg for record type - buffer := make([]byte, syscall.CmsgSpace(1)) - cmsg := (*syscall.Cmsghdr)(unsafe.Pointer(&buffer[0])) - cmsg.SetLen(syscall.CmsgLen(1)) + buffer := make([]byte, unix.CmsgSpace(1)) + cmsg := (*unix.Cmsghdr)(unsafe.Pointer(&buffer[0])) + cmsg.SetLen(unix.CmsgLen(1)) - var iov syscall.Iovec + var iov unix.Iovec iov.Base = &b[0] iov.SetLen(len(b)) - var msg syscall.Msghdr + var msg unix.Msghdr msg.Control = &buffer[0] msg.Controllen = cmsg.Len msg.Iov = &iov msg.Iovlen = 1 - var n int flags := 0 - n, err := recvmsg(uintptr(fd), &msg, flags) - if err == syscall.EAGAIN { + n, err = recvmsg(uintptr(fd), &msg, flags) + if err == unix.EAGAIN { // data is not ready, goroutine will be parked return 0, n, err } @@ -433,7 +431,7 @@ func ktlsReadRecord(fd int, b []byte) (recordType, int, error) { Debugf("kTLS: unsupported cmsg type: %d", cmsg.Type) return 0, 0, fmt.Errorf("unsupported cmsg type: %d", cmsg.Type) } - typ := recordType(buffer[syscall.SizeofCmsghdr]) + typ = recordType(buffer[unix.SizeofCmsghdr]) Debugf("kTLS: recvmsg, type: %d, payload len: %d", typ, n) return typ, n, nil } @@ -459,8 +457,8 @@ func ktlsReadDataFromRecord(fd int, b []byte) (int, error) { } } -func recvmsg(fd uintptr, msg *syscall.Msghdr, flags int) (n int, err error) { - r0, _, e1 := syscall.Syscall(syscall.SYS_RECVMSG, fd, uintptr(unsafe.Pointer(msg)), uintptr(flags)) +func recvmsg(fd uintptr, msg *unix.Msghdr, flags int) (n int, err error) { + r0, _, e1 := unix.Syscall(unix.SYS_RECVMSG, fd, uintptr(unsafe.Pointer(msg)), uintptr(flags)) n = int(r0) if e1 != 0 { err = errnoErr(e1) @@ -468,8 +466,8 @@ func recvmsg(fd uintptr, msg *syscall.Msghdr, flags int) (n int, err error) { return } -func sendmsg(fd uintptr, msg *syscall.Msghdr, flags int) (n int, err error) { - r0, _, e1 := syscall.Syscall(syscall.SYS_SENDMSG, fd, uintptr(unsafe.Pointer(msg)), uintptr(flags)) +func sendmsg(fd uintptr, msg *unix.Msghdr, flags int) (n int, err error) { + r0, _, e1 := unix.Syscall(unix.SYS_SENDMSG, fd, uintptr(unsafe.Pointer(msg)), uintptr(flags)) n = int(r0) if e1 != 0 { err = errnoErr(e1) @@ -480,22 +478,22 @@ func sendmsg(fd uintptr, msg *syscall.Msghdr, flags int) (n int, err error) { // Do the interface allocations only once for common // Errno values. var ( - errEAGAIN error = syscall.EAGAIN - errEINVAL error = syscall.EINVAL - errENOENT error = syscall.ENOENT + errEAGAIN error = unix.EAGAIN + errEINVAL error = unix.EINVAL + errENOENT error = unix.ENOENT ) // errnoErr returns common boxed Errno values, to prevent // allocations at runtime. -func errnoErr(e syscall.Errno) error { +func errnoErr(e unix.Errno) error { switch e { case 0: return nil - case syscall.EAGAIN: + case unix.EAGAIN: return errEAGAIN - case syscall.EINVAL: + case unix.EINVAL: return errEINVAL - case syscall.ENOENT: + case unix.ENOENT: return errENOENT } return e @@ -503,18 +501,18 @@ func errnoErr(e syscall.Errno) error { func ktlsSendCtrlMessage(fd int, typ recordType, b []byte) (int, error) { // cmsg for record type - buffer := make([]byte, syscall.CmsgSpace(1)) - cmsg := (*syscall.Cmsghdr)(unsafe.Pointer(&buffer[0])) - cmsg.SetLen(syscall.CmsgLen(1)) - buffer[syscall.SizeofCmsghdr] = byte(typ) + buffer := make([]byte, unix.CmsgSpace(1)) + cmsg := (*unix.Cmsghdr)(unsafe.Pointer(&buffer[0])) + cmsg.SetLen(unix.CmsgLen(1)) + buffer[unix.SizeofCmsghdr] = byte(typ) cmsg.Level = SOL_TLS cmsg.Type = TLS_SET_RECORD_TYPE - var iov syscall.Iovec + var iov unix.Iovec iov.Base = &b[0] iov.SetLen(len(b)) - var msg syscall.Msghdr + var msg unix.Msghdr msg.Control = &buffer[0] msg.Controllen = cmsg.Len msg.Iov = &iov @@ -523,7 +521,7 @@ func ktlsSendCtrlMessage(fd int, typ recordType, b []byte) (int, error) { var n int flags := 0 n, err := sendmsg(uintptr(fd), &msg, flags) - if err == syscall.EAGAIN { + if err == unix.EAGAIN { // data is not ready, goroutine will be parked return n, err } From d13ead19c7c6eb33296e188be150d174596a408c Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Tue, 31 Jan 2023 02:52:11 +0000 Subject: [PATCH 21/34] opt: remove gnetConn.tlsEnabled & update the doc related to tlsconn --- connection.go | 25 ++++++------------------- eventloop.go | 41 ++++++++++------------------------------- 2 files changed, 16 insertions(+), 50 deletions(-) diff --git a/connection.go b/connection.go index abbd6acee..84ad64041 100644 --- a/connection.go +++ b/connection.go @@ -50,7 +50,6 @@ type conn struct { isDatagram bool // UDP protocol opened bool // connection opened event fired tlsconn *tls.Conn // tls connection - // tlsEnabled bool // whether TLS is enabled } func newTCPConn(fd int, el *eventloop, sa unix.Sockaddr, localAddr, remoteAddr net.Addr) (c *conn) { @@ -84,7 +83,6 @@ func (c *conn) releaseTCP() { c.outboundBuffer.Release() netpoll.PutPollAttachment(c.pollAttachment) c.pollAttachment = nil - // c.tlsEnabled = false } func newUDPConn(fd int, el *eventloop, localAddr net.Addr, sa unix.Sockaddr, connected bool) (c *conn) { @@ -132,15 +130,11 @@ func (c *conn) open(buf []byte) error { } func (c *conn) writeTLS(data []byte) (n int, err error) { - // temporarily disable the TLS connection. - // c.tlsEnabled = false // use tls to encrypt the data before sending it. - // tlsconn will implicitly call gnet.Write (but it runs the plaintext version) to sent the data. - // unsent data will be buffered + // tlsconn will call gnet.WriteTCP() to sent the data directly. + // If gnetConn.outboundBufferis not empty, data will be + // buffered in gnetConn.outboundBuffer. n, err = c.tlsconn.Write(data) - // re-enable the TLS connection if data is sent or is buffered - // Otherwise, the connection is closed (c.loop.closeConn() is called). - // c.tlsEnabled = true return } @@ -177,11 +171,10 @@ func (c *conn) writevTLS(bs [][]byte) (n int, err error) { n += len(b) } - // temporarily disable the TLS connection. - // c.tlsEnabled = false // use tls to encrypt the data before sending it. - // tlsconn will implicitly call gnet.Write (but it runs the plaintext version) to sent the data. - // unsent data will be buffered + // tlsconn will call gnet.WriteTCP() to sent the data directly. + // If gnetConn.outboundBufferis not empty, data will be + // buffered in gnetConn.outboundBuffer. sent := 0 var sentN int for _, b := range bs { @@ -192,11 +185,6 @@ func (c *conn) writevTLS(bs [][]byte) (n int, err error) { } sent += sentN } - - // re-enable the TLS connection if data is sent or is buffered - // Otherwise, the connection is closed (c.loop.closeConn() is called). - // c.tlsEnabled = true - return } @@ -537,7 +525,6 @@ func (c *conn) Close() error { func (c *conn) UpgradeTLS(config *tls.Config) (err error) { // TODO: create a sync.pool to manage the TLS connection c.tlsconn = tls.Server(c, config.Clone()) - // c.tlsEnabled = true //很有可能握手包在UpgradeTls之前发过来了,这里把inboundBuffer剩余数据当做握手数据处理 if c.inboundBuffer.Len() > 0 { diff --git a/eventloop.go b/eventloop.go index 51c2c242b..882c81457 100644 --- a/eventloop.go +++ b/eventloop.go @@ -120,12 +120,13 @@ func (el *eventloop) readTLS(c *conn) error { // TLS records are available for { if err := c.tlsconn.ReadFrame(); err != nil { - // If err is io.EOF, it can either the data is drained, - // receives a close notify from the client. + // Receive error unix.EAGAIN, wait for the next round if err == unix.EAGAIN { c.tlsconn.DataDone() return nil } + // If err is io.EOF, it can either the data is drained, + // receives a close notify from the client. return el.closeConn(c, os.NewSyscallError("TLS read", err)) } @@ -155,11 +156,11 @@ func (el *eventloop) readTLS(c *conn) error { func (el *eventloop) read(c *conn) error { // detected whether kernel TLS RX is enabled // This only happens after TLS handshake is completed. - // Therefore, no need to call c.tlsconn.HandshakeComplete() - // In addition, all data are copied directly from kernel to the buffer, - // el.buffer meaning no need to call unix.read(c.fd, el.buffer) + // Therefore, no need to call c.tlsconn.HandshakeComplete(). if c.tlsconn != nil && c.tlsconn.IsKTLSRXEnabled() { - // attach the gnet eventloop.buffer to tlsconn.rawInput + // attach the gnet eventloop.buffer to tlsconn.rawInput. + // So, KTLS can decrypt the data directly to the buffer without memory allocation. + // Since data is read through KTLS, there is no need to call unix.read(c.fd, el.buffer) c.tlsconn.RawInputSet(el.buffer) return el.readTLS(c) } @@ -176,6 +177,7 @@ func (el *eventloop) read(c *conn) error { } if c.tlsconn != nil { + // attach the gnet eventloop.buffer to tlsconn.rawInput. c.tlsconn.RawInputSet(el.buffer[:n]) if !c.tlsconn.HandshakeComplete() { //先判断是否足够一条消息 @@ -186,7 +188,7 @@ func (el *eventloop) read(c *conn) error { } if err = c.tlsconn.Handshake(); err != nil { // closeConn will cleanup the TLS data at the end, - // so need to call tlsconn.DataDone() + // so no need to call tlsconn.DataDone() return el.closeConn(c, os.NewSyscallError("TLS handshake", err)) } if !c.tlsconn.HandshakeComplete() || len(c.tlsconn.RawInputData()) == 0 { //握手没成功,或者握手成功,但是没有数据黏包了 @@ -264,32 +266,9 @@ func (el *eventloop) closeConn(c *conn, err error) (rerr error) { } // close the TLS connection by sending the alert - // - // tlsconn.Close() is called only if err == nil. - // Notice tlsconn.Close() eventually calls gnetConn.write(). - // gnetConn.write() is possible to call c.loop.closeConn() again if unix.Write() returns an error other than unix.EAGAIN, - // which creates a cycle. The error message could be "broken pipe" or "connection reset by peer" indicating - // the socket is no longer valid. - // - // This implies that once "err is not nil", it will create a cycle, and run the code in an infinite loop. - // Therefore, we use "if err == nil" to detect the cycle and break it. if c.tlsconn != nil { - // The default call graph results calling gnet.writeTLS(). See below - // - // tlsconn.Close() -> tlsconn.sendAlertLocked() -> tlsconn.writeRecordLocked() -> tlsconn.write() -> - // gnetConn.Write() (Here tlsEnabled is true, implying to run the TLS version) -> - // gnetConn.writeTLS() -> tlsconn.Write() -> tlsconn.writeRecordLocked() -> tlsconn.write() -> - // gnetConn.Write() (Here tlsEnabled is false, implying to run the plaintext version) -> - // gnetConn.write() - // - // Therefore, the closing message is encrypted twice which is not correct. - // To resolve the issue, we disable the TLS before closing. Then, when tlsconn.Close() - // calls gnetConn.Write(), which immediately runs the plaintext version. This means - // the encrypted data is written to the socket directly instead of being encrypted one more time. - // c.tlsEnabled = false - // if err == nil { + // close the TLS connection, which will send a close notify to the client c.tlsconn.Close() - // } // Make sure all memory requested from the pool is returned. c.tlsconn.DataCleanUpAfterClose() c.tlsconn = nil From b1b7bc5184093847fb478196525cb78b6964a293 Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Tue, 31 Jan 2023 03:06:10 +0000 Subject: [PATCH 22/34] opt: remove unused MsgBuffer & EMsgBuffer --- pkg/tls/buf.go | 167 ------------------------------------------ pkg/tls/bufElastic.go | 156 --------------------------------------- pkg/tls/bufLazy.go | 2 + 3 files changed, 2 insertions(+), 323 deletions(-) delete mode 100644 pkg/tls/buf.go delete mode 100644 pkg/tls/bufElastic.go diff --git a/pkg/tls/buf.go b/pkg/tls/buf.go deleted file mode 100644 index 0fafc4172..000000000 --- a/pkg/tls/buf.go +++ /dev/null @@ -1,167 +0,0 @@ -package tls - -import ( - "io" - "unsafe" -) - -type MsgBuffer struct { - b []byte - l int // Total length of buffered data - i int // Position of unread buffered data -} - -const ( - blockSize = 8192 // clean up the data when i >= blocksize - appendSize = 4096 - defaultSize = 4096 -) - -// New returns a new MsgBuffer whose buffer has the given size. -func NewMsgBuffer(n int) *MsgBuffer { - return &MsgBuffer{b: make([]byte, 0, n)} -} - -func (w *MsgBuffer) Reset() { - w.l = 0 - w.i = 0 -} - -// clean up the data when i >= blockSize -func (w *MsgBuffer) clean() { - if w.i >= blockSize { - copy(w.b[:w.l-w.i], w.b[w.i:w.l]) - w.l -= w.i - w.i = 0 - } -} - -// grow the buffer size if the size of current buffer cannot fit the new incoming data. -func (w *MsgBuffer) grow() { - if len(w.b) < w.l { - if cap(w.b) < w.l { - add := w.l - len(w.b) - if add > appendSize { - w.b = append(w.b, make([]byte, add)...) - } else { - w.b = append(w.b, make([]byte, appendSize)...) - } - } - w.b = w.b[:w.l] - } -} - -func (w *MsgBuffer) Make(l int) []byte { - w.clean() - o := w.l - w.l += l - w.grow() - return w.b[o:w.l] -} - -func (w *MsgBuffer) Write(b []byte) (int, error) { - w.clean() - l := len(b) - o := w.l - w.l += l - w.grow() - copy(w.b[o:w.l], b) - return l, nil -} - -func (w *MsgBuffer) WriteString(s string) { - w.clean() - x := (*[2]uintptr)(unsafe.Pointer(&s)) - h := [3]uintptr{x[0], x[1], x[1]} - b := *(*[]byte)(unsafe.Pointer(&h)) - l := len(b) - o := w.l - w.l += l - w.grow() - copy(w.b[o:w.l], b) -} - -func (w *MsgBuffer) WriteByte(s byte) error { - w.clean() - w.l++ - w.grow() - w.b[w.l-1] = s - - return nil -} - -func (w *MsgBuffer) Bytes() []byte { - return w.b[w.i:w.l] -} - -func (w *MsgBuffer) Peek(n int) []byte { - end := w.i + n - if end > w.l { - end = w.l - } - return w.b[w.i:end] -} - -func (w *MsgBuffer) Len() int { - return w.l - w.i -} - -func (w *MsgBuffer) Truncate(i int) { - l := w.i + i - if l < w.l { - w.l = l - } -} - -func (w *MsgBuffer) String() string { - b := make([]byte, w.l-w.i) - copy(b, w.b[w.i:w.l]) - return *(*string)(unsafe.Pointer(&b)) -} - -// Discard skips the next n bytes by advancing the read pointer. -func (r *MsgBuffer) Discard(l int) { - if l <= 0 { - return - } - if l < r.Len() { - r.i += l - } else { - r.Reset() - } -} - -func (r *MsgBuffer) Close() error { - return nil -} - -func (r *MsgBuffer) Read(p []byte) (n int, err error) { - if len(p) == 0 { - return 0, nil - } - if r.i == r.l { - return 0, io.EOF - } - o := r.i - r.i += len(p) - if r.i > r.l { - r.i = r.l - } - copy(p, r.b[o:r.i]) - return r.i - o, nil -} - -// ReadByte reads and returns the next byte from the input or ErrIsEmpty. -func (r *MsgBuffer) ReadByte() (b byte, err error) { - if r.i == r.l { - return 0, io.EOF - } - b = r.b[r.i] - r.i++ - return b, err -} - -// IsEmpty tells if this MsgBuffer is empty. -func (b *MsgBuffer) IsEmpty() bool { - return b.i == b.l -} diff --git a/pkg/tls/bufElastic.go b/pkg/tls/bufElastic.go deleted file mode 100644 index 2d0d1c17b..000000000 --- a/pkg/tls/bufElastic.go +++ /dev/null @@ -1,156 +0,0 @@ -package tls - -import ( - "io" - "sync" -) - -// EMsgBuffer is the elastic wrapper of EMsgBuffer. -type EMsgBuffer struct { - mb *MsgBuffer -} - -var msgBufferPool = sync.Pool{ - New: func() any { - return NewMsgBuffer(defaultSize) - }, -} - -func (b *EMsgBuffer) instance() *MsgBuffer { - if b.mb == nil { - b.mb = msgBufferPool.New().(*MsgBuffer) - } - return b.mb -} - -// Done checks and returns the internal MsgBuffer to pool. -func (b *EMsgBuffer) Done() { - if b.mb != nil { - b.mb.Reset() - msgBufferPool.Put(b.mb) - b.mb = nil - } -} - -func (b *EMsgBuffer) DoneIfEmpty() { - b.done() -} - -func (b *EMsgBuffer) done() { - if b.mb != nil && b.mb.IsEmpty() { - b.mb.Reset() - msgBufferPool.Put(b.mb) - b.mb = nil - } -} - -func (b *EMsgBuffer) Make(l int) []byte { - return b.instance().Make(l) -} - -// Write writes len(p) bytes from p to the underlying buf. -func (b *EMsgBuffer) Write(p []byte) (int, error) { - if len(p) == 0 { - return 0, nil - } - return b.instance().Write(p) -} - -// WriteString writes the contents of the string s to buffer, which accepts a slice of bytes. -func (b *EMsgBuffer) WriteString(s string) { - b.instance().WriteString(s) -} - -// WriteByte writes one byte into buffer. -func (b *EMsgBuffer) WriteByte(c byte) error { - return b.instance().WriteByte(c) -} - -// Bytes returns all available read bytes. It does not move the read pointer and only copy the available data. -func (b *EMsgBuffer) Bytes() []byte { - if b.mb == nil { - return nil - } - return b.mb.Bytes() -} - -// Bytes returns first n readable bytes. It does not move the read pointer and only copy the available data. -func (b *EMsgBuffer) Peek(n int) []byte { - if b.mb == nil { - return nil - } - return b.mb.Peek(n) -} - -// Len returns the length of the underlying buffer. -func (b *EMsgBuffer) Len() int { - if b.mb == nil { - return 0 - } - return b.mb.Len() -} - -// truncate the total number of readable bytes to i -func (b *EMsgBuffer) Truncate(i int) { - if b.mb != nil { - b.mb.Truncate(i) - b.done() - } -} - -func (b *EMsgBuffer) String() string { - if b.mb == nil { - return "" - } - return b.mb.String() -} - -// Discard skips the next n bytes by advancing the read pointer. -func (b *EMsgBuffer) Discard(l int) { - if b.mb != nil { - b.mb.Discard(l) - b.done() - } -} - -// Discard skips the next n bytes by advancing the read pointer, but holding the MsgBuffer temporarily. -// Doing so can ensure one can use the data return by Peek not used by another thread. -// Therefore, thread-safe is guaranteed. -func (b *EMsgBuffer) DiscardWithoutDone(l int) { - if b.mb != nil { - b.mb.Discard(l) - } -} - -func (b *EMsgBuffer) Close() error { - if b.mb == nil { - return nil - } - return b.Close() - -} - -func (b *EMsgBuffer) Read(p []byte) (n int, err error) { - if b.mb == nil { - return 0, io.EOF - } - defer b.done() - return b.mb.Read(p) -} - -// ReadByte reads and returns the next byte from the input or ErrIsEmpty. -func (b *EMsgBuffer) ReadByte() (byte, error) { - if b.mb == nil { - return 0, io.EOF - } - defer b.done() - return b.mb.ReadByte() -} - -// IsEmpty tells if this MsgBuffer is empty. -func (b *EMsgBuffer) IsEmpty() bool { - if b.mb == nil { - return true - } - return b.mb.IsEmpty() -} diff --git a/pkg/tls/bufLazy.go b/pkg/tls/bufLazy.go index 0ae104e2c..bfcda3411 100644 --- a/pkg/tls/bufLazy.go +++ b/pkg/tls/bufLazy.go @@ -4,6 +4,8 @@ import ( "sync" ) +const defaultSize = 4096 + var bytePool = sync.Pool{ New: func() any { buf := make([]byte, defaultSize) From e054d94c0785443677ed47722377b6f47fac1ba5 Mon Sep 17 00:00:00 2001 From: Roland Shoemaker Date: Wed, 14 Dec 2022 09:43:16 -0800 Subject: [PATCH 23/34] crypto/tls: replace all usages of BytesOrPanic Message marshalling makes use of BytesOrPanic a lot, under the assumption that it will never panic. This assumption was incorrect, and specifically crafted handshakes could trigger panics. Rather than just surgically replacing the usages of BytesOrPanic in paths that could panic, replace all usages of it with proper error returns in case there are other ways of triggering panics which we didn't find. In one specific case, the tree routed by expandLabel, we replace the usage of BytesOrPanic, but retain a panic. This function already explicitly panicked elsewhere, and returning an error from it becomes rather painful because it requires changing a large number of APIs. The marshalling is unlikely to ever panic, as the inputs are all either fixed length, or already limited to the sizes required. If it were to panic, it'd likely only be during development. A close inspection shows no paths for a user to cause a panic currently. This patches ends up being rather large, since it requires routing errors back through functions which previously had no error returns. Where possible I've tried to use helpers that reduce the verbosity of frequently repeated stanzas, and to make the diffs as minimal as possible. Thanks to Marten Seemann for reporting this issue. Updates #58001 Fixes #58359 Fixes CVE-2022-41724 Change-Id: Ieb55867ef0a3e1e867b33f09421932510cb58851 Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/1679436 Reviewed-by: Julie Qiu TryBot-Result: Security TryBots Run-TryBot: Roland Shoemaker Reviewed-by: Damien Neil (cherry picked from commit 1d4e6ca9454f6cf81d30c5361146fb5988f1b5f6) Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/1728205 Reviewed-by: Tatiana Bradley Reviewed-on: https://go-review.googlesource.com/c/go/+/468121 Reviewed-by: Than McIntosh Auto-Submit: Michael Pratt TryBot-Bypass: Michael Pratt Run-TryBot: Michael Pratt --- pkg/tls/common.go | 2 +- pkg/tls/conn.go | 44 +- pkg/tls/handshake_client.go | 97 ++-- pkg/tls/handshake_client_tls13.go | 74 +-- pkg/tls/handshake_messages.go | 716 +++++++++++++++-------------- pkg/tls/handshake_messages_test.go | 19 +- pkg/tls/handshake_server.go | 73 +-- pkg/tls/handshake_server_tls13.go | 71 +-- pkg/tls/key_schedule.go | 19 +- pkg/tls/ticket.go | 8 +- 10 files changed, 630 insertions(+), 493 deletions(-) diff --git a/pkg/tls/common.go b/pkg/tls/common.go index 007f0f47b..5394d64ac 100644 --- a/pkg/tls/common.go +++ b/pkg/tls/common.go @@ -1394,7 +1394,7 @@ func (c *Certificate) leaf() (*x509.Certificate, error) { } type handshakeMessage interface { - marshal() []byte + marshal() ([]byte, error) unmarshal([]byte) bool } diff --git a/pkg/tls/conn.go b/pkg/tls/conn.go index d921bf998..9e1398c11 100644 --- a/pkg/tls/conn.go +++ b/pkg/tls/conn.go @@ -1018,15 +1018,32 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { return n, nil } -// writeRecord writes a TLS record with the given type and payload to the -// connection and updates the record layer state. -func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) { - return c.writeRecordLocked(typ, data) +// writeHandshakeRecord writes a handshake message to the connection and updates +// the record layer state. If transcript is non-nil the marshalled message is +// written to it. +func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) { + data, err := msg.marshal() + if err != nil { + return 0, err + } + if transcript != nil { + transcript.Write(data) + } + + return c.writeRecordLocked(recordTypeHandshake, data) +} + +// writeChangeCipherRecord writes a ChangeCipherSpec message to the connection and +// updates the record layer state. +func (c *Conn) writeChangeCipherRecord() error { + _, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1}) + return err } // readHandshake reads the next handshake message from -// the record layer. -func (c *Conn) readHandshake() (interface{}, error) { +// the record layer. If transcript is non-nil, the message +// is written to the passed transcriptHash. +func (c *Conn) readHandshake(transcript transcriptHash) (any, error) { for c.hand.Len() < 4 { if err := c.readRecord(); err != nil { return nil, err @@ -1107,6 +1124,11 @@ func (c *Conn) readHandshake() (interface{}, error) { if !m.unmarshal(data) { return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) } + + if transcript != nil { + transcript.Write(data) + } + return m, nil } @@ -1227,7 +1249,7 @@ func (c *Conn) handleRenegotiation() error { return errors.New("tls: internal error: unexpected renegotiation") } - msg, err := c.readHandshake() + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -1273,7 +1295,7 @@ func (c *Conn) handlePostHandshakeMessage() error { return c.handleRenegotiation() } - msg, err := c.readHandshake() + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -1307,7 +1329,11 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { if keyUpdate.updateRequested { msg := &keyUpdateMsg{} - _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal()) + msgBytes, err := msg.marshal() + if err != nil { + return err + } + _, err = c.writeRecordLocked(recordTypeHandshake, msgBytes) if err != nil { // Surface the error at the next write. c.out.setErrorLocked(err) diff --git a/pkg/tls/handshake_client.go b/pkg/tls/handshake_client.go index fbba66a50..8d8633eb1 100644 --- a/pkg/tls/handshake_client.go +++ b/pkg/tls/handshake_client.go @@ -167,7 +167,10 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { } c.serverName = hello.serverName - cacheKey, session, earlySecret, binderKey := c.loadSession(hello) + cacheKey, session, earlySecret, binderKey, err := c.loadSession(hello) + if err != nil { + return err + } if cacheKey != "" && session != nil { defer func() { // If we got a handshake failure when resuming a session, throw away @@ -182,7 +185,7 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { }() } - if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil { + if _, err := c.writeHandshakeRecord(hello, nil); err != nil { return err } c.flush() @@ -202,7 +205,8 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { return nil case 1: hello := c.hs.(*clientHandshakeStateTLS13).hello - msg, err := c.readHandshake() + // serverHelloMsg is not included in the transcript + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -258,9 +262,9 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { } func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, - session *ClientSessionState, earlySecret, binderKey []byte) { + session *ClientSessionState, earlySecret, binderKey []byte, err error) { if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { - return "", nil, nil, nil + return "", nil, nil, nil, nil } hello.ticketSupported = true @@ -275,14 +279,14 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, // renegotiation is primarily used to allow a client to send a client // certificate, which would be skipped if session resumption occurred. if c.handshakes != 0 { - return "", nil, nil, nil + return "", nil, nil, nil, nil } // Try to resume a previously negotiated TLS session, if available. cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) session, ok := c.config.ClientSessionCache.Get(cacheKey) if !ok || session == nil { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } // Check that version used for the previous session is still valid. @@ -294,7 +298,7 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, } } if !versOk { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } // Check that the cached server certificate is not expired, and that it's @@ -303,16 +307,16 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, if !c.config.InsecureSkipVerify { if len(session.verifiedChains) == 0 { // The original connection had InsecureSkipVerify, while this doesn't. - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } serverCert := session.serverCertificates[0] if c.config.time().After(serverCert.NotAfter) { // Expired certificate, delete the entry. c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } if err := serverCert.VerifyHostname(c.config.ServerName); err != nil { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } } @@ -320,7 +324,7 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, // In TLS 1.2 the cipher suite must match the resumed session. Ensure we // are still offering it. if mutualCipherSuite(hello.cipherSuites, session.cipherSuite) == nil { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } hello.sessionTicket = session.sessionTicket @@ -330,14 +334,14 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, // Check that the session ticket is not expired. if c.config.time().After(session.useBy) { c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } // In TLS 1.3 the KDF hash must match the resumed session. Ensure we // offer at least one cipher suite with that hash. cipherSuite := cipherSuiteTLS13ByID(session.cipherSuite) if cipherSuite == nil { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } cipherSuiteOk := false for _, offeredID := range hello.cipherSuites { @@ -348,7 +352,7 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, } } if !cipherSuiteOk { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } // Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1. @@ -366,9 +370,15 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, earlySecret = cipherSuite.extract(psk, nil) binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil) transcript := cipherSuite.hash.New() - transcript.Write(hello.marshalWithoutBinders()) + helloBytes, err := hello.marshalWithoutBinders() + if err != nil { + return "", nil, nil, nil, err + } + transcript.Write(helloBytes) pskBinders := [][]byte{cipherSuite.finishedHash(binderKey, transcript)} - hello.updateBinders(pskBinders) + if err := hello.updateBinders(pskBinders); err != nil { + return "", nil, nil, nil, err + } return } @@ -415,8 +425,12 @@ func (hs *clientHandshakeState) handshake() error { hs.finishedHash.discardHandshakeBuffer() } - hs.finishedHash.Write(hs.hello.marshal()) - hs.finishedHash.Write(hs.serverHello.marshal()) + if err := transcriptMsg(hs.hello, &hs.finishedHash); err != nil { + return err + } + if err := transcriptMsg(hs.serverHello, &hs.finishedHash); err != nil { + return err + } c.handshakeStatus = 3 c.buffering = true } @@ -510,7 +524,7 @@ func (hs *clientHandshakeState) pickCipherSuite() error { func (hs *clientHandshakeState) doFullHandshakeStep1() error { c := hs.c - msg, err := c.readHandshake() + msg, err := c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -519,7 +533,8 @@ func (hs *clientHandshakeState) doFullHandshakeStep1() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(certMsg, msg) } - hs.finishedHash.Write(certMsg.marshal()) + + msg, err = c.readHandshake(&hs.finishedHash) if c.handshakes == 1 || len(c.peerCertificates) == 0 { // If this is the first handshake on a connection, process and // (optionally) verify the server's certificates. @@ -542,7 +557,7 @@ func (hs *clientHandshakeState) doFullHandshakeStep1() error { } func (hs *clientHandshakeState) doFullHandshakeStep2() error { c := hs.c - msg, err := c.readHandshake() + msg, err := c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -560,11 +575,10 @@ func (hs *clientHandshakeState) doFullHandshakeStep2() error { c.sendAlert(alertUnexpectedMessage) return errors.New("tls: received unexpected CertificateStatus message") } - hs.finishedHash.Write(cs.marshal()) c.ocspResponse = cs.response - msg, err = c.readHandshake() + msg, err = c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -574,14 +588,13 @@ func (hs *clientHandshakeState) doFullHandshakeStep2() error { skx, ok := msg.(*serverKeyExchangeMsg) if ok { - hs.finishedHash.Write(skx.marshal()) err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx) if err != nil { c.sendAlert(alertUnexpectedMessage) return err } - msg, err = c.readHandshake() + msg, err = c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -592,7 +605,6 @@ func (hs *clientHandshakeState) doFullHandshakeStep2() error { certReq, ok := msg.(*certificateRequestMsg) if ok { certRequested = true - hs.finishedHash.Write(certReq.marshal()) cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq) if chainToSend, err = c.getClientCertificate(cri); err != nil { @@ -600,7 +612,7 @@ func (hs *clientHandshakeState) doFullHandshakeStep2() error { return err } - msg, err = c.readHandshake() + msg, err = c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -611,7 +623,6 @@ func (hs *clientHandshakeState) doFullHandshakeStep2() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(shd, msg) } - hs.finishedHash.Write(shd.marshal()) // If the server requested a certificate then we have to send a // Certificate message, even if it's empty because we don't have a @@ -619,8 +630,7 @@ func (hs *clientHandshakeState) doFullHandshakeStep2() error { if certRequested { certMsg := new(certificateMsg) certMsg.certificates = chainToSend.Certificate - hs.finishedHash.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil { return err } } @@ -631,8 +641,7 @@ func (hs *clientHandshakeState) doFullHandshakeStep2() error { return err } if ckx != nil { - hs.finishedHash.Write(ckx.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(ckx, &hs.finishedHash); err != nil { return err } } @@ -679,8 +688,7 @@ func (hs *clientHandshakeState) doFullHandshakeStep2() error { return err } - hs.finishedHash.Write(certVerify.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certVerify, &hs.finishedHash); err != nil { return err } } @@ -817,7 +825,10 @@ func (hs *clientHandshakeState) readFinished(out []byte) error { return err } - msg, err := c.readHandshake() + // finishedMsg is included in the transcript, but not until after we + // check the client version, since the state before this message was + // sent is used during verification. + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -833,7 +844,11 @@ func (hs *clientHandshakeState) readFinished(out []byte) error { c.sendAlert(alertHandshakeFailure) return errors.New("tls: server's Finished message was incorrect") } - hs.finishedHash.Write(serverFinished.marshal()) + + if err := transcriptMsg(serverFinished, &hs.finishedHash); err != nil { + return err + } + copy(out, verify) return nil } @@ -844,7 +859,7 @@ func (hs *clientHandshakeState) readSessionTicket() error { } c := hs.c - msg, err := c.readHandshake() + msg, err := c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -853,7 +868,6 @@ func (hs *clientHandshakeState) readSessionTicket() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(sessionTicketMsg, msg) } - hs.finishedHash.Write(sessionTicketMsg.marshal()) hs.session = &ClientSessionState{ sessionTicket: sessionTicketMsg.ticket, @@ -873,14 +887,13 @@ func (hs *clientHandshakeState) readSessionTicket() error { func (hs *clientHandshakeState) sendFinished(out []byte) error { c := hs.c - if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + if err := c.writeChangeCipherRecord(); err != nil { return err } finished := new(finishedMsg) finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) - hs.finishedHash.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil { return err } copy(out, finished.verifyData) diff --git a/pkg/tls/handshake_client_tls13.go b/pkg/tls/handshake_client_tls13.go index 52a75954e..dffe17bab 100644 --- a/pkg/tls/handshake_client_tls13.go +++ b/pkg/tls/handshake_client_tls13.go @@ -62,7 +62,10 @@ func (hs *clientHandshakeStateTLS13) handshake() error { } hs.transcript = hs.suite.hash.New() - hs.transcript.Write(hs.hello.marshal()) + + if err := transcriptMsg(hs.hello, hs.transcript); err != nil { + return err + } if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { if err := hs.sendDummyChangeCipherSpec(); err != nil { @@ -73,7 +76,9 @@ func (hs *clientHandshakeStateTLS13) handshake() error { } } - hs.transcript.Write(hs.serverHello.marshal()) + if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil { + return err + } c.buffering = true if err := hs.processServerHello(); err != nil { @@ -176,8 +181,7 @@ func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error { } hs.sentDummyCCS = true - _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) - return err + return hs.c.writeChangeCipherRecord() } // processHelloRetryRequest handles the HRR in hs.serverHello, modifies and @@ -192,7 +196,9 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { hs.transcript.Reset() hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) hs.transcript.Write(chHash) - hs.transcript.Write(hs.serverHello.marshal()) + if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil { + return err + } // The only HelloRetryRequest extensions we support are key_share and // cookie, and clients must abort the handshake if the HRR would not result @@ -257,10 +263,18 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { transcript := hs.suite.hash.New() transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) transcript.Write(chHash) - transcript.Write(hs.serverHello.marshal()) - transcript.Write(hs.hello.marshalWithoutBinders()) + if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil { + return err + } + helloBytes, err := hs.hello.marshalWithoutBinders() + if err != nil { + return err + } + transcript.Write(helloBytes) pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)} - hs.hello.updateBinders(pskBinders) + if err := hs.hello.updateBinders(pskBinders); err != nil { + return err + } } else { // Server selected a cipher suite incompatible with the PSK. hs.hello.pskIdentities = nil @@ -268,12 +282,12 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { } } - hs.transcript.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil { return err } - msg, err := c.readHandshake() + // serverHelloMsg is not included in the transcript + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -367,6 +381,7 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { if !hs.usingPSK { earlySecret = hs.suite.extract(nil, nil) } + handshakeSecret := hs.suite.extract(sharedKey, hs.suite.deriveSecret(earlySecret, "derived", nil)) @@ -397,7 +412,7 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { func (hs *clientHandshakeStateTLS13) readServerParameters() error { c := hs.c - msg, err := c.readHandshake() + msg, err := c.readHandshake(hs.transcript) if err != nil { return err } @@ -407,7 +422,6 @@ func (hs *clientHandshakeStateTLS13) readServerParameters() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(encryptedExtensions, msg) } - hs.transcript.Write(encryptedExtensions.marshal()) if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil { c.sendAlert(alertUnsupportedExtension) @@ -436,18 +450,16 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { return nil } - msg, err := c.readHandshake() + msg, err := c.readHandshake(hs.transcript) if err != nil { return err } certReq, ok := msg.(*certificateRequestMsgTLS13) if ok { - hs.transcript.Write(certReq.marshal()) - hs.certReq = certReq - msg, err = c.readHandshake() + msg, err = c.readHandshake(hs.transcript) if err != nil { return err } @@ -462,7 +474,6 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { c.sendAlert(alertDecodeError) return errors.New("tls: received empty certificates message") } - hs.transcript.Write(certMsg.marshal()) c.scts = certMsg.certificate.SignedCertificateTimestamps c.ocspResponse = certMsg.certificate.OCSPStaple @@ -471,7 +482,10 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { return err } - msg, err = c.readHandshake() + // certificateVerifyMsg is included in the transcript, but not until + // after we verify the handshake signature, since the state before + // this message was sent is used. + msg, err = c.readHandshake(nil) if err != nil { return err } @@ -502,7 +516,9 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { return errors.New("tls: invalid signature by the server certificate: " + err.Error()) } - hs.transcript.Write(certVerify.marshal()) + if err := transcriptMsg(certVerify, hs.transcript); err != nil { + return err + } return nil } @@ -510,7 +526,10 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { func (hs *clientHandshakeStateTLS13) readServerFinished() error { c := hs.c - msg, err := c.readHandshake() + // finishedMsg is included in the transcript, but not until after we + // check the client version, since the state before this message was + // sent is used during verification. + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -527,7 +546,9 @@ func (hs *clientHandshakeStateTLS13) readServerFinished() error { return errors.New("tls: invalid server finished hash") } - hs.transcript.Write(finished.marshal()) + if err := transcriptMsg(finished, hs.transcript); err != nil { + return err + } // Derive secrets that take context through the server Finished. @@ -575,8 +596,7 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0 certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0 - hs.transcript.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil { return err } @@ -613,8 +633,7 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { } certVerifyMsg.signature = sig - hs.transcript.Write(certVerifyMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil { return err } @@ -628,8 +647,7 @@ func (hs *clientHandshakeStateTLS13) sendClientFinished() error { verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), } - hs.transcript.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil { return err } diff --git a/pkg/tls/handshake_messages.go b/pkg/tls/handshake_messages.go index 7ab0f100b..695aacf12 100644 --- a/pkg/tls/handshake_messages.go +++ b/pkg/tls/handshake_messages.go @@ -5,6 +5,7 @@ package tls import ( + "errors" "fmt" "strings" @@ -94,9 +95,181 @@ type clientHelloMsg struct { pskBinders [][]byte } -func (m *clientHelloMsg) marshal() []byte { +func (m *clientHelloMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil + } + + var exts cryptobyte.Builder + if len(m.serverName) > 0 { + // RFC 6066, Section 3 + exts.AddUint16(extensionServerName) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8(0) // name_type = host_name + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes([]byte(m.serverName)) + }) + }) + }) + } + if m.ocspStapling { + // RFC 4366, Section 3.6 + exts.AddUint16(extensionStatusRequest) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8(1) // status_type = ocsp + exts.AddUint16(0) // empty responder_id_list + exts.AddUint16(0) // empty request_extensions + }) + } + if len(m.supportedCurves) > 0 { + // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 + exts.AddUint16(extensionSupportedCurves) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, curve := range m.supportedCurves { + exts.AddUint16(uint16(curve)) + } + }) + }) + } + if len(m.supportedPoints) > 0 { + // RFC 4492, Section 5.1.2 + exts.AddUint16(extensionSupportedPoints) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.supportedPoints) + }) + }) + } + if m.ticketSupported { + // RFC 5077, Section 3.2 + exts.AddUint16(extensionSessionTicket) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.sessionTicket) + }) + } + if len(m.supportedSignatureAlgorithms) > 0 { + // RFC 5246, Section 7.4.1.4.1 + exts.AddUint16(extensionSignatureAlgorithms) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithms { + exts.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.supportedSignatureAlgorithmsCert) > 0 { + // RFC 8446, Section 4.2.3 + exts.AddUint16(extensionSignatureAlgorithmsCert) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { + exts.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if m.secureRenegotiationSupported { + // RFC 5746, Section 3.2 + exts.AddUint16(extensionRenegotiationInfo) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocols) > 0 { + // RFC 7301, Section 3.1 + exts.AddUint16(extensionALPN) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, proto := range m.alpnProtocols { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes([]byte(proto)) + }) + } + }) + }) + } + if m.scts { + // RFC 6962, Section 3.3.1 + exts.AddUint16(extensionSCT) + exts.AddUint16(0) // empty extension_data + } + if len(m.supportedVersions) > 0 { + // RFC 8446, Section 4.2.1 + exts.AddUint16(extensionSupportedVersions) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, vers := range m.supportedVersions { + exts.AddUint16(vers) + } + }) + }) + } + if len(m.cookie) > 0 { + // RFC 8446, Section 4.2.2 + exts.AddUint16(extensionCookie) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.cookie) + }) + }) + } + if len(m.keyShares) > 0 { + // RFC 8446, Section 4.2.8 + exts.AddUint16(extensionKeyShare) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, ks := range m.keyShares { + exts.AddUint16(uint16(ks.group)) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(ks.data) + }) + } + }) + }) + } + if m.earlyData { + // RFC 8446, Section 4.2.10 + exts.AddUint16(extensionEarlyData) + exts.AddUint16(0) // empty extension_data + } + if len(m.pskModes) > 0 { + // RFC 8446, Section 4.2.9 + exts.AddUint16(extensionPSKModes) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.pskModes) + }) + }) + } + if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension + // RFC 8446, Section 4.2.11 + exts.AddUint16(extensionPreSharedKey) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, psk := range m.pskIdentities { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(psk.label) + }) + exts.AddUint32(psk.obfuscatedTicketAge) + } + }) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, binder := range m.pskBinders { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(binder) + }) + } + }) + }) + } + extBytes, err := exts.Bytes() + if err != nil { + return nil, err } var b cryptobyte.Builder @@ -116,219 +289,53 @@ func (m *clientHelloMsg) marshal() []byte { b.AddBytes(m.compressionMethods) }) - // If extensions aren't present, omit them. - var extensionsPresent bool - bWithoutExtensions := *b - - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if len(m.serverName) > 0 { - // RFC 6066, Section 3 - b.AddUint16(extensionServerName) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(0) // name_type = host_name - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(m.serverName)) - }) - }) - }) - } - if m.ocspStapling { - // RFC 4366, Section 3.6 - b.AddUint16(extensionStatusRequest) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(1) // status_type = ocsp - b.AddUint16(0) // empty responder_id_list - b.AddUint16(0) // empty request_extensions - }) - } - if len(m.supportedCurves) > 0 { - // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 - b.AddUint16(extensionSupportedCurves) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, curve := range m.supportedCurves { - b.AddUint16(uint16(curve)) - } - }) - }) - } - if len(m.supportedPoints) > 0 { - // RFC 4492, Section 5.1.2 - b.AddUint16(extensionSupportedPoints) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.supportedPoints) - }) - }) - } - if m.ticketSupported { - // RFC 5077, Section 3.2 - b.AddUint16(extensionSessionTicket) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.sessionTicket) - }) - } - if len(m.supportedSignatureAlgorithms) > 0 { - // RFC 5246, Section 7.4.1.4.1 - b.AddUint16(extensionSignatureAlgorithms) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithms { - b.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if len(m.supportedSignatureAlgorithmsCert) > 0 { - // RFC 8446, Section 4.2.3 - b.AddUint16(extensionSignatureAlgorithmsCert) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { - b.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if m.secureRenegotiationSupported { - // RFC 5746, Section 3.2 - b.AddUint16(extensionRenegotiationInfo) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.secureRenegotiation) - }) - }) - } - if len(m.alpnProtocols) > 0 { - // RFC 7301, Section 3.1 - b.AddUint16(extensionALPN) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, proto := range m.alpnProtocols { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(proto)) - }) - } - }) - }) - } - if m.scts { - // RFC 6962, Section 3.3.1 - b.AddUint16(extensionSCT) - b.AddUint16(0) // empty extension_data - } - if len(m.supportedVersions) > 0 { - // RFC 8446, Section 4.2.1 - b.AddUint16(extensionSupportedVersions) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - for _, vers := range m.supportedVersions { - b.AddUint16(vers) - } - }) - }) - } - if len(m.cookie) > 0 { - // RFC 8446, Section 4.2.2 - b.AddUint16(extensionCookie) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.cookie) - }) - }) - } - if len(m.keyShares) > 0 { - // RFC 8446, Section 4.2.8 - b.AddUint16(extensionKeyShare) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, ks := range m.keyShares { - b.AddUint16(uint16(ks.group)) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(ks.data) - }) - } - }) - }) - } - if m.earlyData { - // RFC 8446, Section 4.2.10 - b.AddUint16(extensionEarlyData) - b.AddUint16(0) // empty extension_data - } - if len(m.pskModes) > 0 { - // RFC 8446, Section 4.2.9 - b.AddUint16(extensionPSKModes) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.pskModes) - }) - }) - } - if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension - // RFC 8446, Section 4.2.11 - b.AddUint16(extensionPreSharedKey) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, psk := range m.pskIdentities { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(psk.label) - }) - b.AddUint32(psk.obfuscatedTicketAge) - } - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, binder := range m.pskBinders { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(binder) - }) - } - }) - }) - } - - extensionsPresent = len(b.BytesOrPanic()) > 2 - }) - - if !extensionsPresent { - *b = bWithoutExtensions + if len(extBytes) > 0 { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(extBytes) + }) } }) - m.raw = b.BytesOrPanic() - return m.raw + m.raw, err = b.Bytes() + return m.raw, err } // marshalWithoutBinders returns the ClientHello through the // PreSharedKeyExtension.identities field, according to RFC 8446, Section // 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length. -func (m *clientHelloMsg) marshalWithoutBinders() []byte { +func (m *clientHelloMsg) marshalWithoutBinders() ([]byte, error) { bindersLen := 2 // uint16 length prefix for _, binder := range m.pskBinders { bindersLen += 1 // uint8 length prefix bindersLen += len(binder) } - fullMessage := m.marshal() - return fullMessage[:len(fullMessage)-bindersLen] + fullMessage, err := m.marshal() + if err != nil { + return nil, err + } + return fullMessage[:len(fullMessage)-bindersLen], nil } // updateBinders updates the m.pskBinders field, if necessary updating the // cached marshaled representation. The supplied binders must have the same // length as the current m.pskBinders. -func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) { +func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error { if len(pskBinders) != len(m.pskBinders) { - panic("tls: internal error: pskBinders length mismatch") + return errors.New("tls: internal error: pskBinders length mismatch") } for i := range m.pskBinders { if len(pskBinders[i]) != len(m.pskBinders[i]) { - panic("tls: internal error: pskBinders length mismatch") + return errors.New("tls: internal error: pskBinders length mismatch") } } m.pskBinders = pskBinders if m.raw != nil { - lenWithoutBinders := len(m.marshalWithoutBinders()) + helloBytes, err := m.marshalWithoutBinders() + if err != nil { + return err + } + lenWithoutBinders := len(helloBytes) b := cryptobyte.NewFixedBuilder(m.raw[:lenWithoutBinders]) b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { for _, binder := range m.pskBinders { @@ -338,9 +345,11 @@ func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) { } }) if out, err := b.Bytes(); err != nil || len(out) != len(m.raw) { - panic("tls: internal error: failed to update binders") + return errors.New("tls: internal error: failed to update binders") } } + + return nil } func (m *clientHelloMsg) unmarshal(data []byte) bool { @@ -618,9 +627,98 @@ type serverHelloMsg struct { selectedGroup CurveID } -func (m *serverHelloMsg) marshal() []byte { +func (m *serverHelloMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil + } + + var exts cryptobyte.Builder + if m.ocspStapling { + exts.AddUint16(extensionStatusRequest) + exts.AddUint16(0) // empty extension_data + } + if m.ticketSupported { + exts.AddUint16(extensionSessionTicket) + exts.AddUint16(0) // empty extension_data + } + if m.secureRenegotiationSupported { + exts.AddUint16(extensionRenegotiationInfo) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocol) > 0 { + exts.AddUint16(extensionALPN) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes([]byte(m.alpnProtocol)) + }) + }) + }) + } + if len(m.scts) > 0 { + exts.AddUint16(extensionSCT) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, sct := range m.scts { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(sct) + }) + } + }) + }) + } + if m.supportedVersion != 0 { + exts.AddUint16(extensionSupportedVersions) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16(m.supportedVersion) + }) + } + if m.serverShare.group != 0 { + exts.AddUint16(extensionKeyShare) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16(uint16(m.serverShare.group)) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.serverShare.data) + }) + }) + } + if m.selectedIdentityPresent { + exts.AddUint16(extensionPreSharedKey) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16(m.selectedIdentity) + }) + } + + if len(m.cookie) > 0 { + exts.AddUint16(extensionCookie) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.cookie) + }) + }) + } + if m.selectedGroup != 0 { + exts.AddUint16(extensionKeyShare) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16(uint16(m.selectedGroup)) + }) + } + if len(m.supportedPoints) > 0 { + exts.AddUint16(extensionSupportedPoints) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.supportedPoints) + }) + }) + } + + extBytes, err := exts.Bytes() + if err != nil { + return nil, err } var b cryptobyte.Builder @@ -634,104 +732,15 @@ func (m *serverHelloMsg) marshal() []byte { b.AddUint16(m.cipherSuite) b.AddUint8(m.compressionMethod) - // If extensions aren't present, omit them. - var extensionsPresent bool - bWithoutExtensions := *b - - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if m.ocspStapling { - b.AddUint16(extensionStatusRequest) - b.AddUint16(0) // empty extension_data - } - if m.ticketSupported { - b.AddUint16(extensionSessionTicket) - b.AddUint16(0) // empty extension_data - } - if m.secureRenegotiationSupported { - b.AddUint16(extensionRenegotiationInfo) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.secureRenegotiation) - }) - }) - } - if len(m.alpnProtocol) > 0 { - b.AddUint16(extensionALPN) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(m.alpnProtocol)) - }) - }) - }) - } - if len(m.scts) > 0 { - b.AddUint16(extensionSCT) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sct := range m.scts { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(sct) - }) - } - }) - }) - } - if m.supportedVersion != 0 { - b.AddUint16(extensionSupportedVersions) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(m.supportedVersion) - }) - } - if m.serverShare.group != 0 { - b.AddUint16(extensionKeyShare) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(uint16(m.serverShare.group)) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.serverShare.data) - }) - }) - } - if m.selectedIdentityPresent { - b.AddUint16(extensionPreSharedKey) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(m.selectedIdentity) - }) - } - - if len(m.cookie) > 0 { - b.AddUint16(extensionCookie) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.cookie) - }) - }) - } - if m.selectedGroup != 0 { - b.AddUint16(extensionKeyShare) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(uint16(m.selectedGroup)) - }) - } - if len(m.supportedPoints) > 0 { - b.AddUint16(extensionSupportedPoints) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.supportedPoints) - }) - }) - } - - extensionsPresent = len(b.BytesOrPanic()) > 2 - }) - - if !extensionsPresent { - *b = bWithoutExtensions + if len(extBytes) > 0 { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(extBytes) + }) } }) - m.raw = b.BytesOrPanic() - return m.raw + m.raw, err = b.Bytes() + return m.raw, err } func (m *serverHelloMsg) unmarshal(data []byte) bool { @@ -855,9 +864,9 @@ type encryptedExtensionsMsg struct { alpnProtocol string } -func (m *encryptedExtensionsMsg) marshal() []byte { +func (m *encryptedExtensionsMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -877,8 +886,9 @@ func (m *encryptedExtensionsMsg) marshal() []byte { }) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { @@ -926,10 +936,10 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { type endOfEarlyDataMsg struct{} -func (m *endOfEarlyDataMsg) marshal() []byte { +func (m *endOfEarlyDataMsg) marshal() ([]byte, error) { x := make([]byte, 4) x[0] = typeEndOfEarlyData - return x + return x, nil } func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool { @@ -941,9 +951,9 @@ type keyUpdateMsg struct { updateRequested bool } -func (m *keyUpdateMsg) marshal() []byte { +func (m *keyUpdateMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -956,8 +966,9 @@ func (m *keyUpdateMsg) marshal() []byte { } }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *keyUpdateMsg) unmarshal(data []byte) bool { @@ -989,9 +1000,9 @@ type newSessionTicketMsgTLS13 struct { maxEarlyData uint32 } -func (m *newSessionTicketMsgTLS13) marshal() []byte { +func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1016,8 +1027,9 @@ func (m *newSessionTicketMsgTLS13) marshal() []byte { }) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { @@ -1070,9 +1082,9 @@ type certificateRequestMsgTLS13 struct { certificateAuthorities [][]byte } -func (m *certificateRequestMsgTLS13) marshal() []byte { +func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1131,8 +1143,9 @@ func (m *certificateRequestMsgTLS13) marshal() []byte { }) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool { @@ -1216,9 +1229,9 @@ type certificateMsg struct { certificates [][]byte } -func (m *certificateMsg) marshal() (x []byte) { +func (m *certificateMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var i int @@ -1227,7 +1240,7 @@ func (m *certificateMsg) marshal() (x []byte) { } length := 3 + 3*len(m.certificates) + i - x = make([]byte, 4+length) + x := make([]byte, 4+length) x[0] = typeCertificate x[1] = uint8(length >> 16) x[2] = uint8(length >> 8) @@ -1248,7 +1261,7 @@ func (m *certificateMsg) marshal() (x []byte) { } m.raw = x - return + return m.raw, nil } func (m *certificateMsg) unmarshal(data []byte) bool { @@ -1295,9 +1308,9 @@ type certificateMsgTLS13 struct { scts bool } -func (m *certificateMsgTLS13) marshal() []byte { +func (m *certificateMsgTLS13) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1315,8 +1328,9 @@ func (m *certificateMsgTLS13) marshal() []byte { marshalCertificate(b, certificate) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) { @@ -1439,9 +1453,9 @@ type serverKeyExchangeMsg struct { key []byte } -func (m *serverKeyExchangeMsg) marshal() []byte { +func (m *serverKeyExchangeMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } length := len(m.key) x := make([]byte, length+4) @@ -1452,7 +1466,7 @@ func (m *serverKeyExchangeMsg) marshal() []byte { copy(x[4:], m.key) m.raw = x - return x + return x, nil } func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { @@ -1469,9 +1483,9 @@ type certificateStatusMsg struct { response []byte } -func (m *certificateStatusMsg) marshal() []byte { +func (m *certificateStatusMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1483,8 +1497,9 @@ func (m *certificateStatusMsg) marshal() []byte { }) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *certificateStatusMsg) unmarshal(data []byte) bool { @@ -1503,10 +1518,10 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool { type serverHelloDoneMsg struct{} -func (m *serverHelloDoneMsg) marshal() []byte { +func (m *serverHelloDoneMsg) marshal() ([]byte, error) { x := make([]byte, 4) x[0] = typeServerHelloDone - return x + return x, nil } func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { @@ -1518,9 +1533,9 @@ type clientKeyExchangeMsg struct { ciphertext []byte } -func (m *clientKeyExchangeMsg) marshal() []byte { +func (m *clientKeyExchangeMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } length := len(m.ciphertext) x := make([]byte, length+4) @@ -1531,7 +1546,7 @@ func (m *clientKeyExchangeMsg) marshal() []byte { copy(x[4:], m.ciphertext) m.raw = x - return x + return x, nil } func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { @@ -1552,9 +1567,9 @@ type finishedMsg struct { verifyData []byte } -func (m *finishedMsg) marshal() []byte { +func (m *finishedMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1563,8 +1578,9 @@ func (m *finishedMsg) marshal() []byte { b.AddBytes(m.verifyData) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *finishedMsg) unmarshal(data []byte) bool { @@ -1586,9 +1602,9 @@ type certificateRequestMsg struct { certificateAuthorities [][]byte } -func (m *certificateRequestMsg) marshal() (x []byte) { +func (m *certificateRequestMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } // See RFC 4346, Section 7.4.4. @@ -1603,7 +1619,7 @@ func (m *certificateRequestMsg) marshal() (x []byte) { length += 2 + 2*len(m.supportedSignatureAlgorithms) } - x = make([]byte, 4+length) + x := make([]byte, 4+length) x[0] = typeCertificateRequest x[1] = uint8(length >> 16) x[2] = uint8(length >> 8) @@ -1638,7 +1654,7 @@ func (m *certificateRequestMsg) marshal() (x []byte) { } m.raw = x - return + return m.raw, nil } func (m *certificateRequestMsg) unmarshal(data []byte) bool { @@ -1724,9 +1740,9 @@ type certificateVerifyMsg struct { signature []byte } -func (m *certificateVerifyMsg) marshal() (x []byte) { +func (m *certificateVerifyMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1740,8 +1756,9 @@ func (m *certificateVerifyMsg) marshal() (x []byte) { }) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *certificateVerifyMsg) unmarshal(data []byte) bool { @@ -1764,15 +1781,15 @@ type newSessionTicketMsg struct { ticket []byte } -func (m *newSessionTicketMsg) marshal() (x []byte) { +func (m *newSessionTicketMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } // See RFC 5077, Section 3.3. ticketLen := len(m.ticket) length := 2 + 4 + ticketLen - x = make([]byte, 4+length) + x := make([]byte, 4+length) x[0] = typeNewSessionTicket x[1] = uint8(length >> 16) x[2] = uint8(length >> 8) @@ -1783,7 +1800,7 @@ func (m *newSessionTicketMsg) marshal() (x []byte) { m.raw = x - return + return m.raw, nil } func (m *newSessionTicketMsg) unmarshal(data []byte) bool { @@ -1811,10 +1828,25 @@ func (m *newSessionTicketMsg) unmarshal(data []byte) bool { type helloRequestMsg struct { } -func (*helloRequestMsg) marshal() []byte { - return []byte{typeHelloRequest, 0, 0, 0} +func (*helloRequestMsg) marshal() ([]byte, error) { + return []byte{typeHelloRequest, 0, 0, 0}, nil } func (*helloRequestMsg) unmarshal(data []byte) bool { return len(data) == 4 } + +type transcriptHash interface { + Write([]byte) (int, error) +} + +// transcriptMsg is a helper used to marshal and hash messages which typically +// are not written to the wire, and as such aren't hashed during Conn.writeRecord. +func transcriptMsg(msg handshakeMessage, h transcriptHash) error { + data, err := msg.marshal() + if err != nil { + return err + } + h.Write(data) + return nil +} diff --git a/pkg/tls/handshake_messages_test.go b/pkg/tls/handshake_messages_test.go index c6fc8f2bf..206e2fb02 100644 --- a/pkg/tls/handshake_messages_test.go +++ b/pkg/tls/handshake_messages_test.go @@ -38,6 +38,15 @@ var tests = []any{ &certificateMsgTLS13{}, } +func mustMarshal(t *testing.T, msg handshakeMessage) []byte { + t.Helper() + b, err := msg.marshal() + if err != nil { + t.Fatal(err) + } + return b +} + func TestMarshalUnmarshal(t *testing.T) { rand := rand.New(rand.NewSource(time.Now().UnixNano())) @@ -56,7 +65,7 @@ func TestMarshalUnmarshal(t *testing.T) { } m1 := v.Interface().(handshakeMessage) - marshaled := m1.marshal() + marshaled := mustMarshal(t, m1) m2 := iface.(handshakeMessage) if !m2.unmarshal(marshaled) { t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) @@ -409,12 +418,12 @@ func TestRejectEmptySCTList(t *testing.T) { var random [32]byte sct := []byte{0x42, 0x42, 0x42, 0x42} - serverHello := serverHelloMsg{ + serverHello := &serverHelloMsg{ vers: VersionTLS12, random: random[:], scts: [][]byte{sct}, } - serverHelloBytes := serverHello.marshal() + serverHelloBytes := mustMarshal(t, serverHello) var serverHelloCopy serverHelloMsg if !serverHelloCopy.unmarshal(serverHelloBytes) { @@ -452,12 +461,12 @@ func TestRejectEmptySCT(t *testing.T) { // not be zero length. var random [32]byte - serverHello := serverHelloMsg{ + serverHello := &serverHelloMsg{ vers: VersionTLS12, random: random[:], scts: [][]byte{nil}, } - serverHelloBytes := serverHello.marshal() + serverHelloBytes := mustMarshal(t, serverHello) var serverHelloCopy serverHelloMsg if serverHelloCopy.unmarshal(serverHelloBytes) { diff --git a/pkg/tls/handshake_server.go b/pkg/tls/handshake_server.go index b6d8f5c9b..0198aca3f 100644 --- a/pkg/tls/handshake_server.go +++ b/pkg/tls/handshake_server.go @@ -169,7 +169,9 @@ func (hs *serverHandshakeState) handshake() error { // readClientHello reads a ClientHello message and selects the protocol version. func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) { - msg, err := c.readHandshake() + // clientHelloMsg is included in the transcript, but we haven't initialized + // it yet. The respective handshake functions will record it themselves. + msg, err := c.readHandshake(nil) if err != nil { return nil, err } @@ -503,9 +505,10 @@ func (hs *serverHandshakeState) doResumeHandshake() error { hs.hello.ticketSupported = hs.sessionState.usedOldKey hs.finishedHash = newFinishedHash(c.vers, hs.suite) hs.finishedHash.discardHandshakeBuffer() - hs.finishedHash.Write(hs.clientHello.marshal()) - hs.finishedHash.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil { + return err + } + if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil { return err } @@ -543,24 +546,23 @@ func (hs *serverHandshakeState) doFullHandshakeStep1() error { // certificates won't be used. hs.finishedHash.discardHandshakeBuffer() } - hs.finishedHash.Write(hs.clientHello.marshal()) - hs.finishedHash.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil { + return err + } + if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil { return err } certMsg := new(certificateMsg) certMsg.certificates = hs.cert.Certificate - hs.finishedHash.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil { return err } if hs.hello.ocspStapling { certStatus := new(certificateStatusMsg) certStatus.response = hs.cert.OCSPStaple - hs.finishedHash.Write(certStatus.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certStatus, &hs.finishedHash); err != nil { return err } } @@ -572,8 +574,7 @@ func (hs *serverHandshakeState) doFullHandshakeStep1() error { return err } if skx != nil { - hs.finishedHash.Write(skx.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(skx, &hs.finishedHash); err != nil { return err } } @@ -598,15 +599,13 @@ func (hs *serverHandshakeState) doFullHandshakeStep1() error { if c.config.ClientCAs != nil { hs.certReq.certificateAuthorities = c.config.ClientCAs.Subjects() } - hs.finishedHash.Write(hs.certReq.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.certReq.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(hs.certReq, &hs.finishedHash); err != nil { return err } } helloDone := new(serverHelloDoneMsg) - hs.finishedHash.Write(helloDone.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(helloDone, &hs.finishedHash); err != nil { return err } @@ -619,7 +618,7 @@ func (hs *serverHandshakeState) doFullHandshakeStep2() error { var pub crypto.PublicKey // public key for client auth, if any - msg, err := c.readHandshake() + msg, err := c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -632,7 +631,6 @@ func (hs *serverHandshakeState) doFullHandshakeStep2() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(certMsg, msg) } - hs.finishedHash.Write(certMsg.marshal()) if err := c.processCertsFromClient(Certificate{ Certificate: certMsg.certificates, @@ -643,7 +641,7 @@ func (hs *serverHandshakeState) doFullHandshakeStep2() error { pub = c.peerCertificates[0].PublicKey } - msg, err = c.readHandshake() + msg, err = c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -661,7 +659,6 @@ func (hs *serverHandshakeState) doFullHandshakeStep2() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(ckx, msg) } - hs.finishedHash.Write(ckx.marshal()) preMasterSecret, err := hs.keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers) if err != nil { @@ -681,7 +678,10 @@ func (hs *serverHandshakeState) doFullHandshakeStep2() error { // to the client's certificate. This allows us to verify that the client is in // possession of the private key of the certificate. if len(c.peerCertificates) > 0 { - msg, err = c.readHandshake() + // certificateVerifyMsg is included in the transcript, but not until + // after we verify the handshake signature, since the state before + // this message was sent is used. + msg, err = c.readHandshake(nil) if err != nil { return err } @@ -716,7 +716,9 @@ func (hs *serverHandshakeState) doFullHandshakeStep2() error { return errors.New("tls: invalid signature by the client certificate: " + err.Error()) } - hs.finishedHash.Write(certVerify.marshal()) + if err := transcriptMsg(certVerify, &hs.finishedHash); err != nil { + return err + } } hs.finishedHash.discardHandshakeBuffer() @@ -758,7 +760,10 @@ func (hs *serverHandshakeState) readFinished(out []byte) error { return err } - msg, err := c.readHandshake() + // finishedMsg is included in the transcript, but not until after we + // check the client version, since the state before this message was + // sent is used during verification. + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -775,7 +780,10 @@ func (hs *serverHandshakeState) readFinished(out []byte) error { return errors.New("tls: client's Finished message is incorrect") } - hs.finishedHash.Write(clientFinished.marshal()) + if err := transcriptMsg(clientFinished, &hs.finishedHash); err != nil { + return err + } + copy(out, verify) return nil } @@ -809,14 +817,16 @@ func (hs *serverHandshakeState) sendSessionTicket() error { masterSecret: hs.masterSecret, certificates: certsFromClient, } - var err error - m.ticket, err = c.encryptTicket(state.marshal()) + stateBytes, err := state.marshal() + if err != nil { + return err + } + m.ticket, err = c.encryptTicket(stateBytes) if err != nil { return err } - hs.finishedHash.Write(m.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(m, &hs.finishedHash); err != nil { return err } @@ -826,14 +836,13 @@ func (hs *serverHandshakeState) sendSessionTicket() error { func (hs *serverHandshakeState) sendFinished(out []byte) error { c := hs.c - if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + if err := c.writeChangeCipherRecord(); err != nil { return err } finished := new(finishedMsg) finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) - hs.finishedHash.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil { return err } diff --git a/pkg/tls/handshake_server_tls13.go b/pkg/tls/handshake_server_tls13.go index ce879c67b..0d838824e 100644 --- a/pkg/tls/handshake_server_tls13.go +++ b/pkg/tls/handshake_server_tls13.go @@ -315,7 +315,12 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { c.sendAlert(alertInternalError) return errors.New("tls: internal error: failed to clone hash") } - transcript.Write(hs.clientHello.marshalWithoutBinders()) + clientHelloBytes, err := hs.clientHello.marshalWithoutBinders() + if err != nil { + c.sendAlert(alertInternalError) + return err + } + transcript.Write(clientHelloBytes) pskBinder := hs.suite.finishedHash(binderKey, transcript) if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) { c.sendAlert(alertDecryptError) @@ -406,8 +411,7 @@ func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error { } hs.sentDummyCCS = true - _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) - return err + return hs.c.writeChangeCipherRecord() } func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error { @@ -415,7 +419,9 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) // The first ClientHello gets double-hashed into the transcript upon a // HelloRetryRequest. See RFC 8446, Section 4.4.1. - hs.transcript.Write(hs.clientHello.marshal()) + if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil { + return err + } chHash := hs.transcript.Sum(nil) hs.transcript.Reset() hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) @@ -431,8 +437,7 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) selectedGroup: selectedGroup, } - hs.transcript.Write(helloRetryRequest.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, helloRetryRequest.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(helloRetryRequest, hs.transcript); err != nil { return err } @@ -440,7 +445,8 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) return err } - msg, err := c.readHandshake() + // clientHelloMsg is not included in the transcript. + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -531,9 +537,10 @@ func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool { func (hs *serverHandshakeStateTLS13) sendServerParameters() error { c := hs.c - hs.transcript.Write(hs.clientHello.marshal()) - hs.transcript.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil { + return err + } + if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil { return err } @@ -576,8 +583,7 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error { encryptedExtensions.alpnProtocol = selectedProto c.clientProtocol = selectedProto - hs.transcript.Write(encryptedExtensions.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, encryptedExtensions.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(encryptedExtensions, hs.transcript); err != nil { return err } @@ -606,8 +612,7 @@ func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { certReq.certificateAuthorities = c.config.ClientCAs.Subjects() } - hs.transcript.Write(certReq.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certReq, hs.transcript); err != nil { return err } } @@ -618,8 +623,7 @@ func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0 certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 - hs.transcript.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil { return err } @@ -650,8 +654,7 @@ func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { } certVerifyMsg.signature = sig - hs.transcript.Write(certVerifyMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil { return err } @@ -665,8 +668,7 @@ func (hs *serverHandshakeStateTLS13) sendServerFinished() error { verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), } - hs.transcript.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil { return err } @@ -727,7 +729,9 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { finishedMsg := &finishedMsg{ verifyData: hs.clientFinished, } - hs.transcript.Write(finishedMsg.marshal()) + if err := transcriptMsg(finishedMsg, hs.transcript); err != nil { + return err + } if !hs.shouldSendSessionTickets() { return nil @@ -752,8 +756,12 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { SignedCertificateTimestamps: c.scts, }, } - var err error - m.label, err = c.encryptTicket(state.marshal()) + stateBytes, err := state.marshal() + if err != nil { + c.sendAlert(alertInternalError) + return err + } + m.label, err = c.encryptTicket(stateBytes) if err != nil { return err } @@ -772,7 +780,7 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { // ticket_nonce, which must be unique per connection, is always left at // zero because we only ever send one ticket per connection. - if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + if _, err := c.writeHandshakeRecord(m, nil); err != nil { return err } @@ -797,7 +805,7 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { // If we requested a client certificate, then the client must send a // certificate message. If it's empty, no CertificateVerify is sent. - msg, err := c.readHandshake() + msg, err := c.readHandshake(hs.transcript) if err != nil { return err } @@ -807,7 +815,6 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(certMsg, msg) } - hs.transcript.Write(certMsg.marshal()) if err := c.processCertsFromClient(certMsg.certificate); err != nil { return err @@ -821,7 +828,10 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { } if len(certMsg.certificate.Certificate) != 0 { - msg, err = c.readHandshake() + // certificateVerifyMsg is included in the transcript, but not until + // after we verify the handshake signature, since the state before + // this message was sent is used. + msg, err = c.readHandshake(nil) if err != nil { return err } @@ -852,7 +862,9 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { return errors.New("tls: invalid signature by the client certificate: " + err.Error()) } - hs.transcript.Write(certVerify.marshal()) + if err := transcriptMsg(certVerify, hs.transcript); err != nil { + return err + } } // If we waited until the client certificates to send session tickets, we @@ -867,7 +879,8 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { func (hs *serverHandshakeStateTLS13) readClientFinished() error { c := hs.c - msg, err := c.readHandshake() + // finishedMsg is not included in the transcript. + msg, err := c.readHandshake(nil) if err != nil { return err } diff --git a/pkg/tls/key_schedule.go b/pkg/tls/key_schedule.go index 8150d804a..ae8f80a7c 100644 --- a/pkg/tls/key_schedule.go +++ b/pkg/tls/key_schedule.go @@ -8,6 +8,7 @@ import ( "crypto/ecdh" "crypto/hmac" "errors" + "fmt" "hash" "io" @@ -40,8 +41,24 @@ func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []by hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { b.AddBytes(context) }) + hkdfLabelBytes, err := hkdfLabel.Bytes() + if err != nil { + // Rather than calling BytesOrPanic, we explicitly handle this error, in + // order to provide a reasonable error message. It should be basically + // impossible for this to panic, and routing errors back through the + // tree rooted in this function is quite painful. The labels are fixed + // size, and the context is either a fixed-length computed hash, or + // parsed from a field which has the same length limitation. As such, an + // error here is likely to only be caused during development. + // + // NOTE: another reasonable approach here might be to return a + // randomized slice if we encounter an error, which would break the + // connection, but avoid panicking. This would perhaps be safer but + // significantly more confusing to users. + panic(fmt.Errorf("failed to construct HKDF label: %s", err)) + } out := make([]byte, length) - n, err := hkdf.Expand(c.hash.New, secret, hkdfLabel.BytesOrPanic()).Read(out) + n, err := hkdf.Expand(c.hash.New, secret, hkdfLabelBytes).Read(out) if err != nil || n != length { panic("tls: HKDF-Expand-Label invocation failed unexpectedly") } diff --git a/pkg/tls/ticket.go b/pkg/tls/ticket.go index 6c1d20da2..b82ccd141 100644 --- a/pkg/tls/ticket.go +++ b/pkg/tls/ticket.go @@ -32,7 +32,7 @@ type sessionState struct { usedOldKey bool } -func (m *sessionState) marshal() []byte { +func (m *sessionState) marshal() ([]byte, error) { var b cryptobyte.Builder b.AddUint16(m.vers) b.AddUint16(m.cipherSuite) @@ -47,7 +47,7 @@ func (m *sessionState) marshal() []byte { }) } }) - return b.BytesOrPanic() + return b.Bytes() } func (m *sessionState) unmarshal(data []byte) bool { @@ -86,7 +86,7 @@ type sessionStateTLS13 struct { certificate Certificate // CertificateEntry certificate_list<0..2^24-1>; } -func (m *sessionStateTLS13) marshal() []byte { +func (m *sessionStateTLS13) marshal() ([]byte, error) { var b cryptobyte.Builder b.AddUint16(VersionTLS13) b.AddUint8(0) // revision @@ -96,7 +96,7 @@ func (m *sessionStateTLS13) marshal() []byte { b.AddBytes(m.resumptionSecret) }) marshalCertificate(&b, m.certificate) - return b.BytesOrPanic() + return b.Bytes() } func (m *sessionStateTLS13) unmarshal(data []byte) bool { From 2b05f3294ed8a5ac9a1a44016ead867742d35ea6 Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Mon, 20 Feb 2023 19:19:33 +0000 Subject: [PATCH 24/34] Fix: add missing ctx --- pkg/tls/handshake_client_tls13.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/tls/handshake_client_tls13.go b/pkg/tls/handshake_client_tls13.go index dffe17bab..3e83198b6 100644 --- a/pkg/tls/handshake_client_tls13.go +++ b/pkg/tls/handshake_client_tls13.go @@ -585,6 +585,7 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { AcceptableCAs: hs.certReq.certificateAuthorities, SignatureSchemes: hs.certReq.supportedSignatureAlgorithms, Version: c.vers, + ctx: hs.ctx, }) if err != nil { return err From f45a29f304a2ce6d6216e0c360fc1a76699acce7 Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Sat, 1 Apr 2023 07:06:09 +0000 Subject: [PATCH 25/34] opt: make the TLS implementation as an external library --- go.mod | 7 +- go.sum | 13 +- internal/boring/doc.go | 19 - internal/boring/notboring.go | 123 -- internal/boring/sig/sig.go | 17 - internal/boring/sig/sig_amd64.s | 54 - internal/boring/sig/sig_other.s | 20 - pkg/tls/alert.go | 99 -- pkg/tls/auth.go | 293 ----- pkg/tls/auth_test.go | 168 --- pkg/tls/bufLazy.go | 143 --- pkg/tls/bufLazy_test.go | 108 -- pkg/tls/cache.go | 95 -- pkg/tls/cache_test.go | 117 -- pkg/tls/cipher_suites.go | 702 ----------- pkg/tls/common.go | 1510 ----------------------- pkg/tls/common_string.go | 116 -- pkg/tls/conn.go | 1546 ----------------------- pkg/tls/generate_cert.go | 171 --- pkg/tls/go120.go | 87 ++ pkg/tls/go_oldversion.go | 5 + pkg/tls/handshake_client.go | 1074 ---------------- pkg/tls/handshake_client_tls13.go | 713 ----------- pkg/tls/handshake_messages.go | 1852 ---------------------------- pkg/tls/handshake_messages_test.go | 495 -------- pkg/tls/handshake_server.go | 934 -------------- pkg/tls/handshake_server_tls13.go | 902 -------------- pkg/tls/handshake_test.go | 530 -------- pkg/tls/handshake_unix_test.go | 18 - pkg/tls/key_agreement.go | 366 ------ pkg/tls/key_schedule.go | 158 --- pkg/tls/key_schedule_test.go | 175 --- pkg/tls/ktls.go | 17 - pkg/tls/ktls_cipher_linux.go | 414 ------- pkg/tls/ktls_io.go | 36 - pkg/tls/ktls_linux.go | 534 -------- pkg/tls/ktls_log_debug.go | 16 - pkg/tls/ktls_log_release.go | 8 - pkg/tls/ktls_others.go | 26 - pkg/tls/notboring.go | 20 - pkg/tls/prf.go | 283 ----- pkg/tls/prf_test.go | 140 --- pkg/tls/ticket.go | 185 --- pkg/tls/tls.go | 193 --- pkg/tls/tls_test.go | 25 - 45 files changed, 104 insertions(+), 14423 deletions(-) delete mode 100644 internal/boring/doc.go delete mode 100644 internal/boring/notboring.go delete mode 100644 internal/boring/sig/sig.go delete mode 100644 internal/boring/sig/sig_amd64.s delete mode 100644 internal/boring/sig/sig_other.s delete mode 100644 pkg/tls/alert.go delete mode 100644 pkg/tls/auth.go delete mode 100644 pkg/tls/auth_test.go delete mode 100644 pkg/tls/bufLazy.go delete mode 100644 pkg/tls/bufLazy_test.go delete mode 100644 pkg/tls/cache.go delete mode 100644 pkg/tls/cache_test.go delete mode 100644 pkg/tls/cipher_suites.go delete mode 100644 pkg/tls/common.go delete mode 100644 pkg/tls/common_string.go delete mode 100644 pkg/tls/conn.go delete mode 100644 pkg/tls/generate_cert.go create mode 100644 pkg/tls/go120.go create mode 100644 pkg/tls/go_oldversion.go delete mode 100644 pkg/tls/handshake_client.go delete mode 100644 pkg/tls/handshake_client_tls13.go delete mode 100644 pkg/tls/handshake_messages.go delete mode 100644 pkg/tls/handshake_messages_test.go delete mode 100644 pkg/tls/handshake_server.go delete mode 100644 pkg/tls/handshake_server_tls13.go delete mode 100644 pkg/tls/handshake_test.go delete mode 100644 pkg/tls/handshake_unix_test.go delete mode 100644 pkg/tls/key_agreement.go delete mode 100644 pkg/tls/key_schedule.go delete mode 100644 pkg/tls/key_schedule_test.go delete mode 100644 pkg/tls/ktls.go delete mode 100644 pkg/tls/ktls_cipher_linux.go delete mode 100644 pkg/tls/ktls_io.go delete mode 100644 pkg/tls/ktls_linux.go delete mode 100644 pkg/tls/ktls_log_debug.go delete mode 100644 pkg/tls/ktls_log_release.go delete mode 100644 pkg/tls/ktls_others.go delete mode 100644 pkg/tls/notboring.go delete mode 100644 pkg/tls/prf.go delete mode 100644 pkg/tls/prf_test.go delete mode 100644 pkg/tls/ticket.go delete mode 100644 pkg/tls/tls.go delete mode 100644 pkg/tls/tls_test.go diff --git a/go.mod b/go.mod index 97216c37c..b048f0d46 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,13 @@ module github.com/panjf2000/gnet/v2 require ( + github.com/0-haha/gnet_go_tls/v120 v120.0.1 github.com/panjf2000/ants/v2 v2.7.1 - github.com/stretchr/testify v1.8.1 + github.com/stretchr/testify v1.8.2 github.com/valyala/bytebufferpool v1.0.0 go.uber.org/zap v1.21.0 - golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 - golang.org/x/sys v0.3.0 + golang.org/x/crypto v0.5.0 + golang.org/x/sys v0.4.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0 ) diff --git a/go.sum b/go.sum index 8a93adf87..fb5427c7e 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/0-haha/gnet_go_tls/v120 v120.0.1 h1:oFVjzpqQO4k3MfbU211oFSx8sRSALXto7T5u2bbYqTI= +github.com/0-haha/gnet_go_tls/v120 v120.0.1/go.mod h1:ZDwYfvBBzRwvZNENOXeVI+QjVDKr5r8aEKLMDSchRkI= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= @@ -23,8 +25,9 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= @@ -40,8 +43,8 @@ go.uber.org/zap v1.21.0 h1:WefMeulhovoZ2sYXz7st6K0sLj7bBhpiFaud4r4zST8= go.uber.org/zap v1.21.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 h1:pLI5jrR7OSLijeIDcmRxNmw2api+jEfxLoykJVice/E= -golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= +golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -57,8 +60,8 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= -golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= +golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= diff --git a/internal/boring/doc.go b/internal/boring/doc.go deleted file mode 100644 index 6060fe595..000000000 --- a/internal/boring/doc.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2017 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package boring provides access to BoringCrypto implementation functions. -// Check the constant Enabled to find out whether BoringCrypto is available. -// If BoringCrypto is not available, the functions in this package all panic. -package boring - -// Enabled reports whether BoringCrypto is available. -// When enabled is false, all functions in this package panic. -// -// BoringCrypto is only available on linux/amd64 systems. -const Enabled = available - -// A BigInt is the raw words from a BigInt. -// This definition allows us to avoid importing math/big. -// Conversion between BigInt and *big.Int is in crypto/internal/boring/bbig. -type BigInt []uint diff --git a/internal/boring/notboring.go b/internal/boring/notboring.go deleted file mode 100644 index 6341d5b16..000000000 --- a/internal/boring/notboring.go +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2017 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build !(boringcrypto && linux && (amd64 || arm64) && !android && !cmd_go_bootstrap && !msan && cgo) - -package boring - -import ( - "crypto" - "crypto/cipher" - "hash" - - "github.com/panjf2000/gnet/v2/internal/boring/sig" -) - -const available = false - -// Unreachable marks code that should be unreachable -// when BoringCrypto is in use. It is a no-op without BoringCrypto. -func Unreachable() { - // Code that's unreachable when using BoringCrypto - // is exactly the code we want to detect for reporting - // standard Go crypto. - sig.StandardCrypto() -} - -// UnreachableExceptTests marks code that should be unreachable -// when BoringCrypto is in use. It is a no-op without BoringCrypto. -func UnreachableExceptTests() {} - -type randReader int - -func (randReader) Read(b []byte) (int, error) { panic("boringcrypto: not available") } - -const RandReader = randReader(0) - -func NewSHA1() hash.Hash { panic("boringcrypto: not available") } -func NewSHA224() hash.Hash { panic("boringcrypto: not available") } -func NewSHA256() hash.Hash { panic("boringcrypto: not available") } -func NewSHA384() hash.Hash { panic("boringcrypto: not available") } -func NewSHA512() hash.Hash { panic("boringcrypto: not available") } - -func SHA1([]byte) [20]byte { panic("boringcrypto: not available") } -func SHA224([]byte) [28]byte { panic("boringcrypto: not available") } -func SHA256([]byte) [32]byte { panic("boringcrypto: not available") } -func SHA384([]byte) [48]byte { panic("boringcrypto: not available") } -func SHA512([]byte) [64]byte { panic("boringcrypto: not available") } - -func NewHMAC(h func() hash.Hash, key []byte) hash.Hash { panic("boringcrypto: not available") } - -func NewAESCipher(key []byte) (cipher.Block, error) { panic("boringcrypto: not available") } -func NewGCMTLS(cipher.Block) (cipher.AEAD, error) { panic("boringcrypto: not available") } - -type PublicKeyECDSA struct{ _ int } -type PrivateKeyECDSA struct{ _ int } - -func GenerateKeyECDSA(curve string) (X, Y, D BigInt, err error) { - panic("boringcrypto: not available") -} -func NewPrivateKeyECDSA(curve string, X, Y, D BigInt) (*PrivateKeyECDSA, error) { - panic("boringcrypto: not available") -} -func NewPublicKeyECDSA(curve string, X, Y BigInt) (*PublicKeyECDSA, error) { - panic("boringcrypto: not available") -} -func SignMarshalECDSA(priv *PrivateKeyECDSA, hash []byte) ([]byte, error) { - panic("boringcrypto: not available") -} -func VerifyECDSA(pub *PublicKeyECDSA, hash []byte, sig []byte) bool { - panic("boringcrypto: not available") -} - -type PublicKeyRSA struct{ _ int } -type PrivateKeyRSA struct{ _ int } - -func DecryptRSAOAEP(h, mgfHash hash.Hash, priv *PrivateKeyRSA, ciphertext, label []byte) ([]byte, error) { - panic("boringcrypto: not available") -} -func DecryptRSAPKCS1(priv *PrivateKeyRSA, ciphertext []byte) ([]byte, error) { - panic("boringcrypto: not available") -} -func DecryptRSANoPadding(priv *PrivateKeyRSA, ciphertext []byte) ([]byte, error) { - panic("boringcrypto: not available") -} -func EncryptRSAOAEP(h, mgfHash hash.Hash, pub *PublicKeyRSA, msg, label []byte) ([]byte, error) { - panic("boringcrypto: not available") -} -func EncryptRSAPKCS1(pub *PublicKeyRSA, msg []byte) ([]byte, error) { - panic("boringcrypto: not available") -} -func EncryptRSANoPadding(pub *PublicKeyRSA, msg []byte) ([]byte, error) { - panic("boringcrypto: not available") -} -func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv BigInt, err error) { - panic("boringcrypto: not available") -} -func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv BigInt) (*PrivateKeyRSA, error) { - panic("boringcrypto: not available") -} -func NewPublicKeyRSA(N, E BigInt) (*PublicKeyRSA, error) { panic("boringcrypto: not available") } -func SignRSAPKCS1v15(priv *PrivateKeyRSA, h crypto.Hash, hashed []byte) ([]byte, error) { - panic("boringcrypto: not available") -} -func SignRSAPSS(priv *PrivateKeyRSA, h crypto.Hash, hashed []byte, saltLen int) ([]byte, error) { - panic("boringcrypto: not available") -} -func VerifyRSAPKCS1v15(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte) error { - panic("boringcrypto: not available") -} -func VerifyRSAPSS(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte, saltLen int) error { - panic("boringcrypto: not available") -} - -type PublicKeyECDH struct{} -type PrivateKeyECDH struct{} - -func ECDH(*PrivateKeyECDH, *PublicKeyECDH) ([]byte, error) { panic("boringcrypto: not available") } -func GenerateKeyECDH(string) (*PrivateKeyECDH, []byte, error) { panic("boringcrypto: not available") } -func NewPrivateKeyECDH(string, []byte) (*PrivateKeyECDH, error) { panic("boringcrypto: not available") } -func NewPublicKeyECDH(string, []byte) (*PublicKeyECDH, error) { panic("boringcrypto: not available") } -func (*PublicKeyECDH) Bytes() []byte { panic("boringcrypto: not available") } -func (*PrivateKeyECDH) PublicKey() (*PublicKeyECDH, error) { panic("boringcrypto: not available") } diff --git a/internal/boring/sig/sig.go b/internal/boring/sig/sig.go deleted file mode 100644 index 716c03c5e..000000000 --- a/internal/boring/sig/sig.go +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright 2017 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package sig holds “code signatures” that can be called -// and will result in certain code sequences being linked into -// the final binary. The functions themselves are no-ops. -package sig - -// BoringCrypto indicates that the BoringCrypto module is present. -func BoringCrypto() - -// FIPSOnly indicates that package crypto/tls/fipsonly is present. -func FIPSOnly() - -// StandardCrypto indicates that standard Go crypto is present. -func StandardCrypto() diff --git a/internal/boring/sig/sig_amd64.s b/internal/boring/sig/sig_amd64.s deleted file mode 100644 index 64e3462e4..000000000 --- a/internal/boring/sig/sig_amd64.s +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2017 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -#include "textflag.h" - -// These functions are no-ops, but you can search for their implementations -// to find out whether they are linked into a particular binary. -// -// Each function consists of a two-byte jump over the next 29-bytes, -// then a 5-byte indicator sequence unlikely to occur in real x86 instructions, -// then a randomly-chosen 24-byte sequence, and finally a return instruction -// (the target of the jump). -// -// These sequences are known to rsc.io/goversion. - -#define START \ - BYTE $0xEB; BYTE $0x1D; BYTE $0xF4; BYTE $0x48; BYTE $0xF4; BYTE $0x4B; BYTE $0xF4 - -#define END \ - BYTE $0xC3 - -// BoringCrypto indicates that BoringCrypto (in particular, its func init) is present. -TEXT ·BoringCrypto(SB),NOSPLIT,$0 - START - BYTE $0xB3; BYTE $0x32; BYTE $0xF5; BYTE $0x28; - BYTE $0x13; BYTE $0xA3; BYTE $0xB4; BYTE $0x50; - BYTE $0xD4; BYTE $0x41; BYTE $0xCC; BYTE $0x24; - BYTE $0x85; BYTE $0xF0; BYTE $0x01; BYTE $0x45; - BYTE $0x4E; BYTE $0x92; BYTE $0x10; BYTE $0x1B; - BYTE $0x1D; BYTE $0x2F; BYTE $0x19; BYTE $0x50; - END - -// StandardCrypto indicates that standard Go crypto is present. -TEXT ·StandardCrypto(SB),NOSPLIT,$0 - START - BYTE $0xba; BYTE $0xee; BYTE $0x4d; BYTE $0xfa; - BYTE $0x98; BYTE $0x51; BYTE $0xca; BYTE $0x56; - BYTE $0xa9; BYTE $0x11; BYTE $0x45; BYTE $0xe8; - BYTE $0x3e; BYTE $0x99; BYTE $0xc5; BYTE $0x9c; - BYTE $0xf9; BYTE $0x11; BYTE $0xcb; BYTE $0x8e; - BYTE $0x80; BYTE $0xda; BYTE $0xf1; BYTE $0x2f; - END - -// FIPSOnly indicates that crypto/tls/fipsonly is present. -TEXT ·FIPSOnly(SB),NOSPLIT,$0 - START - BYTE $0x36; BYTE $0x3C; BYTE $0xB9; BYTE $0xCE; - BYTE $0x9D; BYTE $0x68; BYTE $0x04; BYTE $0x7D; - BYTE $0x31; BYTE $0xF2; BYTE $0x8D; BYTE $0x32; - BYTE $0x5D; BYTE $0x5C; BYTE $0xA5; BYTE $0x87; - BYTE $0x3F; BYTE $0x5D; BYTE $0x80; BYTE $0xCA; - BYTE $0xF6; BYTE $0xD6; BYTE $0x15; BYTE $0x1B; - END diff --git a/internal/boring/sig/sig_other.s b/internal/boring/sig/sig_other.s deleted file mode 100644 index 2bbb1df30..000000000 --- a/internal/boring/sig/sig_other.s +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright 2017 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// These functions are no-ops. -// On amd64 they have recognizable implementations, so that you can -// search a particular binary to see if they are present. -// On other platforms (those using this source file), they don't. - -//go:build !amd64 -// +build !amd64 - -TEXT ·BoringCrypto(SB),$0 - RET - -TEXT ·FIPSOnly(SB),$0 - RET - -TEXT ·StandardCrypto(SB),$0 - RET diff --git a/pkg/tls/alert.go b/pkg/tls/alert.go deleted file mode 100644 index 4790b7372..000000000 --- a/pkg/tls/alert.go +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import "strconv" - -type alert uint8 - -const ( - // alert level - alertLevelWarning = 1 - alertLevelError = 2 -) - -const ( - alertCloseNotify alert = 0 - alertUnexpectedMessage alert = 10 - alertBadRecordMAC alert = 20 - alertDecryptionFailed alert = 21 - alertRecordOverflow alert = 22 - alertDecompressionFailure alert = 30 - alertHandshakeFailure alert = 40 - alertBadCertificate alert = 42 - alertUnsupportedCertificate alert = 43 - alertCertificateRevoked alert = 44 - alertCertificateExpired alert = 45 - alertCertificateUnknown alert = 46 - alertIllegalParameter alert = 47 - alertUnknownCA alert = 48 - alertAccessDenied alert = 49 - alertDecodeError alert = 50 - alertDecryptError alert = 51 - alertExportRestriction alert = 60 - alertProtocolVersion alert = 70 - alertInsufficientSecurity alert = 71 - alertInternalError alert = 80 - alertInappropriateFallback alert = 86 - alertUserCanceled alert = 90 - alertNoRenegotiation alert = 100 - alertMissingExtension alert = 109 - alertUnsupportedExtension alert = 110 - alertCertificateUnobtainable alert = 111 - alertUnrecognizedName alert = 112 - alertBadCertificateStatusResponse alert = 113 - alertBadCertificateHashValue alert = 114 - alertUnknownPSKIdentity alert = 115 - alertCertificateRequired alert = 116 - alertNoApplicationProtocol alert = 120 -) - -var alertText = map[alert]string{ - alertCloseNotify: "close notify", - alertUnexpectedMessage: "unexpected message", - alertBadRecordMAC: "bad record MAC", - alertDecryptionFailed: "decryption failed", - alertRecordOverflow: "record overflow", - alertDecompressionFailure: "decompression failure", - alertHandshakeFailure: "handshake failure", - alertBadCertificate: "bad certificate", - alertUnsupportedCertificate: "unsupported certificate", - alertCertificateRevoked: "revoked certificate", - alertCertificateExpired: "expired certificate", - alertCertificateUnknown: "unknown certificate", - alertIllegalParameter: "illegal parameter", - alertUnknownCA: "unknown certificate authority", - alertAccessDenied: "access denied", - alertDecodeError: "error decoding message", - alertDecryptError: "error decrypting message", - alertExportRestriction: "export restriction", - alertProtocolVersion: "protocol version not supported", - alertInsufficientSecurity: "insufficient security level", - alertInternalError: "internal error", - alertInappropriateFallback: "inappropriate fallback", - alertUserCanceled: "user canceled", - alertNoRenegotiation: "no renegotiation", - alertMissingExtension: "missing extension", - alertUnsupportedExtension: "unsupported extension", - alertCertificateUnobtainable: "certificate unobtainable", - alertUnrecognizedName: "unrecognized name", - alertBadCertificateStatusResponse: "bad certificate status response", - alertBadCertificateHashValue: "bad certificate hash value", - alertUnknownPSKIdentity: "unknown PSK identity", - alertCertificateRequired: "certificate required", - alertNoApplicationProtocol: "no application protocol", -} - -func (e alert) String() string { - s, ok := alertText[e] - if ok { - return "tls: " + s - } - return "tls: alert(" + strconv.Itoa(int(e)) + ")" -} - -func (e alert) Error() string { - return e.String() -} diff --git a/pkg/tls/auth.go b/pkg/tls/auth.go deleted file mode 100644 index 7c5675c6d..000000000 --- a/pkg/tls/auth.go +++ /dev/null @@ -1,293 +0,0 @@ -// Copyright 2017 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "bytes" - "crypto" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/elliptic" - "crypto/rsa" - "errors" - "fmt" - "hash" - "io" -) - -// verifyHandshakeSignature verifies a signature against pre-hashed -// (if required) handshake contents. -func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, signed, sig []byte) error { - switch sigType { - case signatureECDSA: - pubKey, ok := pubkey.(*ecdsa.PublicKey) - if !ok { - return fmt.Errorf("expected an ECDSA public key, got %T", pubkey) - } - if !ecdsa.VerifyASN1(pubKey, signed, sig) { - return errors.New("ECDSA verification failure") - } - case signatureEd25519: - pubKey, ok := pubkey.(ed25519.PublicKey) - if !ok { - return fmt.Errorf("expected an Ed25519 public key, got %T", pubkey) - } - if !ed25519.Verify(pubKey, signed, sig) { - return errors.New("Ed25519 verification failure") - } - case signaturePKCS1v15: - pubKey, ok := pubkey.(*rsa.PublicKey) - if !ok { - return fmt.Errorf("expected an RSA public key, got %T", pubkey) - } - if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, signed, sig); err != nil { - return err - } - case signatureRSAPSS: - pubKey, ok := pubkey.(*rsa.PublicKey) - if !ok { - return fmt.Errorf("expected an RSA public key, got %T", pubkey) - } - signOpts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash} - if err := rsa.VerifyPSS(pubKey, hashFunc, signed, sig, signOpts); err != nil { - return err - } - default: - return errors.New("internal error: unknown signature type") - } - return nil -} - -const ( - serverSignatureContext = "TLS 1.3, server CertificateVerify\x00" - clientSignatureContext = "TLS 1.3, client CertificateVerify\x00" -) - -var signaturePadding = []byte{ - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, -} - -// signedMessage returns the pre-hashed (if necessary) message to be signed by -// certificate keys in TLS 1.3. See RFC 8446, Section 4.4.3. -func signedMessage(sigHash crypto.Hash, context string, transcript hash.Hash) []byte { - if sigHash == directSigning { - b := &bytes.Buffer{} - b.Write(signaturePadding) - io.WriteString(b, context) - b.Write(transcript.Sum(nil)) - return b.Bytes() - } - h := sigHash.New() - h.Write(signaturePadding) - io.WriteString(h, context) - h.Write(transcript.Sum(nil)) - return h.Sum(nil) -} - -// typeAndHashFromSignatureScheme returns the corresponding signature type and -// crypto.Hash for a given TLS SignatureScheme. -func typeAndHashFromSignatureScheme(signatureAlgorithm SignatureScheme) (sigType uint8, hash crypto.Hash, err error) { - switch signatureAlgorithm { - case PKCS1WithSHA1, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512: - sigType = signaturePKCS1v15 - case PSSWithSHA256, PSSWithSHA384, PSSWithSHA512: - sigType = signatureRSAPSS - case ECDSAWithSHA1, ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512: - sigType = signatureECDSA - case Ed25519: - sigType = signatureEd25519 - default: - return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) - } - switch signatureAlgorithm { - case PKCS1WithSHA1, ECDSAWithSHA1: - hash = crypto.SHA1 - case PKCS1WithSHA256, PSSWithSHA256, ECDSAWithP256AndSHA256: - hash = crypto.SHA256 - case PKCS1WithSHA384, PSSWithSHA384, ECDSAWithP384AndSHA384: - hash = crypto.SHA384 - case PKCS1WithSHA512, PSSWithSHA512, ECDSAWithP521AndSHA512: - hash = crypto.SHA512 - case Ed25519: - hash = directSigning - default: - return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) - } - return sigType, hash, nil -} - -// legacyTypeAndHashFromPublicKey returns the fixed signature type and crypto.Hash for -// a given public key used with TLS 1.0 and 1.1, before the introduction of -// signature algorithm negotiation. -func legacyTypeAndHashFromPublicKey(pub crypto.PublicKey) (sigType uint8, hash crypto.Hash, err error) { - switch pub.(type) { - case *rsa.PublicKey: - return signaturePKCS1v15, crypto.MD5SHA1, nil - case *ecdsa.PublicKey: - return signatureECDSA, crypto.SHA1, nil - case ed25519.PublicKey: - // RFC 8422 specifies support for Ed25519 in TLS 1.0 and 1.1, - // but it requires holding on to a handshake transcript to do a - // full signature, and not even OpenSSL bothers with the - // complexity, so we can't even test it properly. - return 0, 0, fmt.Errorf("tls: Ed25519 public keys are not supported before TLS 1.2") - default: - return 0, 0, fmt.Errorf("tls: unsupported public key: %T", pub) - } -} - -var rsaSignatureSchemes = []struct { - scheme SignatureScheme - minModulusBytes int - maxVersion uint16 -}{ - // RSA-PSS is used with PSSSaltLengthEqualsHash, and requires - // emLen >= hLen + sLen + 2 - {PSSWithSHA256, crypto.SHA256.Size()*2 + 2, VersionTLS13}, - {PSSWithSHA384, crypto.SHA384.Size()*2 + 2, VersionTLS13}, - {PSSWithSHA512, crypto.SHA512.Size()*2 + 2, VersionTLS13}, - // PKCS #1 v1.5 uses prefixes from hashPrefixes in crypto/rsa, and requires - // emLen >= len(prefix) + hLen + 11 - // TLS 1.3 dropped support for PKCS #1 v1.5 in favor of RSA-PSS. - {PKCS1WithSHA256, 19 + crypto.SHA256.Size() + 11, VersionTLS12}, - {PKCS1WithSHA384, 19 + crypto.SHA384.Size() + 11, VersionTLS12}, - {PKCS1WithSHA512, 19 + crypto.SHA512.Size() + 11, VersionTLS12}, - {PKCS1WithSHA1, 15 + crypto.SHA1.Size() + 11, VersionTLS12}, -} - -// signatureSchemesForCertificate returns the list of supported SignatureSchemes -// for a given certificate, based on the public key and the protocol version, -// and optionally filtered by its explicit SupportedSignatureAlgorithms. -// -// This function must be kept in sync with supportedSignatureAlgorithms. -// FIPS filtering is applied in the caller, selectSignatureScheme. -func signatureSchemesForCertificate(version uint16, cert *Certificate) []SignatureScheme { - priv, ok := cert.PrivateKey.(crypto.Signer) - if !ok { - return nil - } - - var sigAlgs []SignatureScheme - switch pub := priv.Public().(type) { - case *ecdsa.PublicKey: - if version != VersionTLS13 { - // In TLS 1.2 and earlier, ECDSA algorithms are not - // constrained to a single curve. - sigAlgs = []SignatureScheme{ - ECDSAWithP256AndSHA256, - ECDSAWithP384AndSHA384, - ECDSAWithP521AndSHA512, - ECDSAWithSHA1, - } - break - } - switch pub.Curve { - case elliptic.P256(): - sigAlgs = []SignatureScheme{ECDSAWithP256AndSHA256} - case elliptic.P384(): - sigAlgs = []SignatureScheme{ECDSAWithP384AndSHA384} - case elliptic.P521(): - sigAlgs = []SignatureScheme{ECDSAWithP521AndSHA512} - default: - return nil - } - case *rsa.PublicKey: - size := pub.Size() - sigAlgs = make([]SignatureScheme, 0, len(rsaSignatureSchemes)) - for _, candidate := range rsaSignatureSchemes { - if size >= candidate.minModulusBytes && version <= candidate.maxVersion { - sigAlgs = append(sigAlgs, candidate.scheme) - } - } - case ed25519.PublicKey: - sigAlgs = []SignatureScheme{Ed25519} - default: - return nil - } - - if cert.SupportedSignatureAlgorithms != nil { - var filteredSigAlgs []SignatureScheme - for _, sigAlg := range sigAlgs { - if isSupportedSignatureAlgorithm(sigAlg, cert.SupportedSignatureAlgorithms) { - filteredSigAlgs = append(filteredSigAlgs, sigAlg) - } - } - return filteredSigAlgs - } - return sigAlgs -} - -// selectSignatureScheme picks a SignatureScheme from the peer's preference list -// that works with the selected certificate. It's only called for protocol -// versions that support signature algorithms, so TLS 1.2 and 1.3. -func selectSignatureScheme(vers uint16, c *Certificate, peerAlgs []SignatureScheme) (SignatureScheme, error) { - supportedAlgs := signatureSchemesForCertificate(vers, c) - if len(supportedAlgs) == 0 { - return 0, unsupportedCertificateError(c) - } - if len(peerAlgs) == 0 && vers == VersionTLS12 { - // For TLS 1.2, if the client didn't send signature_algorithms then we - // can assume that it supports SHA1. See RFC 5246, Section 7.4.1.4.1. - peerAlgs = []SignatureScheme{PKCS1WithSHA1, ECDSAWithSHA1} - } - // Pick signature scheme in the peer's preference order, as our - // preference order is not configurable. - for _, preferredAlg := range peerAlgs { - if needFIPS() && !isSupportedSignatureAlgorithm(preferredAlg, fipsSupportedSignatureAlgorithms) { - continue - } - if isSupportedSignatureAlgorithm(preferredAlg, supportedAlgs) { - return preferredAlg, nil - } - } - return 0, errors.New("tls: peer doesn't support any of the certificate's signature algorithms") -} - -// unsupportedCertificateError returns a helpful error for certificates with -// an unsupported private key. -func unsupportedCertificateError(cert *Certificate) error { - switch cert.PrivateKey.(type) { - case rsa.PrivateKey, ecdsa.PrivateKey: - return fmt.Errorf("tls: unsupported certificate: private key is %T, expected *%T", - cert.PrivateKey, cert.PrivateKey) - case *ed25519.PrivateKey: - return fmt.Errorf("tls: unsupported certificate: private key is *ed25519.PrivateKey, expected ed25519.PrivateKey") - } - - signer, ok := cert.PrivateKey.(crypto.Signer) - if !ok { - return fmt.Errorf("tls: certificate private key (%T) does not implement crypto.Signer", - cert.PrivateKey) - } - - switch pub := signer.Public().(type) { - case *ecdsa.PublicKey: - switch pub.Curve { - case elliptic.P256(): - case elliptic.P384(): - case elliptic.P521(): - default: - return fmt.Errorf("tls: unsupported certificate curve (%s)", pub.Curve.Params().Name) - } - case *rsa.PublicKey: - return fmt.Errorf("tls: certificate RSA key size too small for supported signature algorithms") - case ed25519.PublicKey: - default: - return fmt.Errorf("tls: unsupported certificate key (%T)", pub) - } - - if cert.SupportedSignatureAlgorithms != nil { - return fmt.Errorf("tls: peer doesn't support the certificate custom signature algorithms") - } - - return fmt.Errorf("tls: internal error: unsupported key (%T)", cert.PrivateKey) -} diff --git a/pkg/tls/auth_test.go b/pkg/tls/auth_test.go deleted file mode 100644 index c23d93f3c..000000000 --- a/pkg/tls/auth_test.go +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright 2017 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "crypto" - "testing" -) - -func TestSignatureSelection(t *testing.T) { - rsaCert := &Certificate{ - Certificate: [][]byte{testRSACertificate}, - PrivateKey: testRSAPrivateKey, - } - pkcs1Cert := &Certificate{ - Certificate: [][]byte{testRSACertificate}, - PrivateKey: testRSAPrivateKey, - SupportedSignatureAlgorithms: []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256}, - } - ecdsaCert := &Certificate{ - Certificate: [][]byte{testP256Certificate}, - PrivateKey: testP256PrivateKey, - } - ed25519Cert := &Certificate{ - Certificate: [][]byte{testEd25519Certificate}, - PrivateKey: testEd25519PrivateKey, - } - - tests := []struct { - cert *Certificate - peerSigAlgs []SignatureScheme - tlsVersion uint16 - - expectedSigAlg SignatureScheme - expectedSigType uint8 - expectedHash crypto.Hash - }{ - {rsaCert, []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256}, VersionTLS12, PKCS1WithSHA1, signaturePKCS1v15, crypto.SHA1}, - {rsaCert, []SignatureScheme{PKCS1WithSHA512, PKCS1WithSHA1}, VersionTLS12, PKCS1WithSHA512, signaturePKCS1v15, crypto.SHA512}, - {rsaCert, []SignatureScheme{PSSWithSHA256, PKCS1WithSHA256}, VersionTLS12, PSSWithSHA256, signatureRSAPSS, crypto.SHA256}, - {pkcs1Cert, []SignatureScheme{PSSWithSHA256, PKCS1WithSHA256}, VersionTLS12, PKCS1WithSHA256, signaturePKCS1v15, crypto.SHA256}, - {rsaCert, []SignatureScheme{PSSWithSHA384, PKCS1WithSHA1}, VersionTLS13, PSSWithSHA384, signatureRSAPSS, crypto.SHA384}, - {ecdsaCert, []SignatureScheme{ECDSAWithSHA1}, VersionTLS12, ECDSAWithSHA1, signatureECDSA, crypto.SHA1}, - {ecdsaCert, []SignatureScheme{ECDSAWithP256AndSHA256}, VersionTLS12, ECDSAWithP256AndSHA256, signatureECDSA, crypto.SHA256}, - {ecdsaCert, []SignatureScheme{ECDSAWithP256AndSHA256}, VersionTLS13, ECDSAWithP256AndSHA256, signatureECDSA, crypto.SHA256}, - {ed25519Cert, []SignatureScheme{Ed25519}, VersionTLS12, Ed25519, signatureEd25519, directSigning}, - {ed25519Cert, []SignatureScheme{Ed25519}, VersionTLS13, Ed25519, signatureEd25519, directSigning}, - - // TLS 1.2 without signature_algorithms extension - {rsaCert, nil, VersionTLS12, PKCS1WithSHA1, signaturePKCS1v15, crypto.SHA1}, - {ecdsaCert, nil, VersionTLS12, ECDSAWithSHA1, signatureECDSA, crypto.SHA1}, - - // TLS 1.2 does not restrict the ECDSA curve (our ecdsaCert is P-256) - {ecdsaCert, []SignatureScheme{ECDSAWithP384AndSHA384}, VersionTLS12, ECDSAWithP384AndSHA384, signatureECDSA, crypto.SHA384}, - } - - for testNo, test := range tests { - sigAlg, err := selectSignatureScheme(test.tlsVersion, test.cert, test.peerSigAlgs) - if err != nil { - t.Errorf("test[%d]: unexpected selectSignatureScheme error: %v", testNo, err) - } - if test.expectedSigAlg != sigAlg { - t.Errorf("test[%d]: expected signature scheme %v, got %v", testNo, test.expectedSigAlg, sigAlg) - } - sigType, hashFunc, err := typeAndHashFromSignatureScheme(sigAlg) - if err != nil { - t.Errorf("test[%d]: unexpected typeAndHashFromSignatureScheme error: %v", testNo, err) - } - if test.expectedSigType != sigType { - t.Errorf("test[%d]: expected signature algorithm %#x, got %#x", testNo, test.expectedSigType, sigType) - } - if test.expectedHash != hashFunc { - t.Errorf("test[%d]: expected hash function %#x, got %#x", testNo, test.expectedHash, hashFunc) - } - } - - brokenCert := &Certificate{ - Certificate: [][]byte{testRSACertificate}, - PrivateKey: testRSAPrivateKey, - SupportedSignatureAlgorithms: []SignatureScheme{Ed25519}, - } - - badTests := []struct { - cert *Certificate - peerSigAlgs []SignatureScheme - tlsVersion uint16 - }{ - {rsaCert, []SignatureScheme{ECDSAWithP256AndSHA256, ECDSAWithSHA1}, VersionTLS12}, - {ecdsaCert, []SignatureScheme{PKCS1WithSHA256, PKCS1WithSHA1}, VersionTLS12}, - {rsaCert, []SignatureScheme{0}, VersionTLS12}, - {ed25519Cert, []SignatureScheme{ECDSAWithP256AndSHA256, ECDSAWithSHA1}, VersionTLS12}, - {ecdsaCert, []SignatureScheme{Ed25519}, VersionTLS12}, - {brokenCert, []SignatureScheme{Ed25519}, VersionTLS12}, - {brokenCert, []SignatureScheme{PKCS1WithSHA256}, VersionTLS12}, - // RFC 5246, Section 7.4.1.4.1, says to only consider {sha1,ecdsa} as - // default when the extension is missing, and RFC 8422 does not update - // it. Anyway, if a stack supports Ed25519 it better support sigalgs. - {ed25519Cert, nil, VersionTLS12}, - // TLS 1.3 has no default signature_algorithms. - {rsaCert, nil, VersionTLS13}, - {ecdsaCert, nil, VersionTLS13}, - {ed25519Cert, nil, VersionTLS13}, - // Wrong curve, which TLS 1.3 checks - {ecdsaCert, []SignatureScheme{ECDSAWithP384AndSHA384}, VersionTLS13}, - // TLS 1.3 does not support PKCS1v1.5 or SHA-1. - {rsaCert, []SignatureScheme{PKCS1WithSHA256}, VersionTLS13}, - {pkcs1Cert, []SignatureScheme{PSSWithSHA256, PKCS1WithSHA256}, VersionTLS13}, - {ecdsaCert, []SignatureScheme{ECDSAWithSHA1}, VersionTLS13}, - // The key can be too small for the hash. - {rsaCert, []SignatureScheme{PSSWithSHA512}, VersionTLS12}, - } - - for testNo, test := range badTests { - sigAlg, err := selectSignatureScheme(test.tlsVersion, test.cert, test.peerSigAlgs) - if err == nil { - t.Errorf("test[%d]: unexpected success, got %v", testNo, sigAlg) - } - } -} - -func TestLegacyTypeAndHash(t *testing.T) { - sigType, hashFunc, err := legacyTypeAndHashFromPublicKey(testRSAPrivateKey.Public()) - if err != nil { - t.Errorf("RSA: unexpected error: %v", err) - } - if expectedSigType := signaturePKCS1v15; expectedSigType != sigType { - t.Errorf("RSA: expected signature type %#x, got %#x", expectedSigType, sigType) - } - if expectedHashFunc := crypto.MD5SHA1; expectedHashFunc != hashFunc { - t.Errorf("RSA: expected hash %#x, got %#x", expectedHashFunc, hashFunc) - } - - sigType, hashFunc, err = legacyTypeAndHashFromPublicKey(testECDSAPrivateKey.Public()) - if err != nil { - t.Errorf("ECDSA: unexpected error: %v", err) - } - if expectedSigType := signatureECDSA; expectedSigType != sigType { - t.Errorf("ECDSA: expected signature type %#x, got %#x", expectedSigType, sigType) - } - if expectedHashFunc := crypto.SHA1; expectedHashFunc != hashFunc { - t.Errorf("ECDSA: expected hash %#x, got %#x", expectedHashFunc, hashFunc) - } - - // Ed25519 is not supported by TLS 1.0 and 1.1. - _, _, err = legacyTypeAndHashFromPublicKey(testEd25519PrivateKey.Public()) - if err == nil { - t.Errorf("Ed25519: unexpected success") - } -} - -// TestSupportedSignatureAlgorithms checks that all supportedSignatureAlgorithms -// have valid type and hash information. -func TestSupportedSignatureAlgorithms(t *testing.T) { - for _, sigAlg := range supportedSignatureAlgorithms() { - sigType, hash, err := typeAndHashFromSignatureScheme(sigAlg) - if err != nil { - t.Errorf("%v: unexpected error: %v", sigAlg, err) - } - if sigType == 0 { - t.Errorf("%v: missing signature type", sigAlg) - } - if hash == 0 && sigAlg != Ed25519 { - t.Errorf("%v: missing hash", sigAlg) - } - } -} diff --git a/pkg/tls/bufLazy.go b/pkg/tls/bufLazy.go deleted file mode 100644 index bfcda3411..000000000 --- a/pkg/tls/bufLazy.go +++ /dev/null @@ -1,143 +0,0 @@ -package tls - -import ( - "sync" -) - -const defaultSize = 4096 - -var bytePool = sync.Pool{ - New: func() any { - buf := make([]byte, defaultSize) - return &buf - }, -} - -type LazyBuffer struct { - buf []byte - ref *[]byte -} - -func (lb *LazyBuffer) Bytes() []byte { - return lb.buf -} - -func (lb *LazyBuffer) Len() int { - return len(lb.buf) -} - -func (lb *LazyBuffer) Next(n int) []byte { - m := lb.Len() - if n > m { - n = m - } - data := lb.buf[:n] - lb.buf = lb.buf[n:] - return data -} - -func (lb *LazyBuffer) tryGrowByReslice(n int) (int, bool) { - if l := len(lb.buf); n <= cap(lb.buf)-l { - lb.buf = lb.buf[:l+n] - return l, true - } - return 0, false -} - -func (lb *LazyBuffer) growSlice(n int) { - // TODO(http://golang.org/issue/51462): We should rely on the append-make - // pattern so that the compiler can call runtime.growslice. For example: - // return append(b, make([]byte, n)...) - // This avoids unnecessary zero-ing of the first len(b) bytes of the - // allocated slice, but this pattern causes b to escape onto the heap. - // - // Instead use the append-make pattern with a nil slice to ensure that - // we allocate buffers rounded up to the closest size class. - c := len(lb.buf) + n // ensure enough space for n elements - if c < 2*cap(*lb.ref) { - // The growth rate has historically always been 2x. In the future, - // we could rely purely on append to determine the growth rate. - c = 2 * cap(*lb.ref) - } - b2 := append([]byte(nil), make([]byte, c)...) - copy(b2, lb.buf) - lb.buf = b2 - lb.ref = &b2 -} - -func (lb *LazyBuffer) grow(n int) int { - m := lb.Len() - // Try to grow by means of a reslice. - if i, ok := lb.tryGrowByReslice(n); ok { - return i - } - - if lb.ref == nil { - lb.ref = bytePool.Get().(*[]byte) - } - - c := cap(*lb.ref) - if n <= c/2-m { - // We can slide things down instead of allocating a new - // slice. We only need m+n <= c to slide, but - // we instead let capacity get twice as large so we - // don't spend all our time copying. - copy(*lb.ref, lb.buf) - } else { - // Add b.off to account for b.buf[:b.off] being sliced off the front. - lb.growSlice(n) - } - - lb.buf = (*lb.ref)[:m+n] - return m -} - -func (lb *LazyBuffer) Grow(n int) { - m := lb.grow(n) - lb.buf = lb.buf[:m] -} - -func (lb *LazyBuffer) Extend(n int) { - lb.grow(n) -} - -func (lb *LazyBuffer) Truncate(n int) { - if 0 <= n && n <= len(lb.buf) { - lb.buf = lb.buf[:n] - } -} - -func (lb *LazyBuffer) Write(p []byte) (n int, err error) { - m := len(lb.buf) - if lb.ref == nil { - oldData := lb.buf - lb.ref = bytePool.Get().(*[]byte) - lb.buf = (*lb.ref)[:0] - lb.grow(m + len(p)) - copy(lb.buf, oldData) - } else { - lb.grow(len(p)) - } - - return copy(lb.buf[m:], p), nil -} - -func (lb *LazyBuffer) Set(p []byte) { - if lb.ref == nil && lb.Len() == 0 { - lb.buf = p - } else { - lb.Write(p) - } -} - -func (lb *LazyBuffer) Done() { - lb.buf = nil - if lb.ref != nil { - bytePool.Put(lb.ref) - lb.ref = nil - } -} - -func (lb *LazyBuffer) IsLazy() bool { - return lb.ref == nil -} diff --git a/pkg/tls/bufLazy_test.go b/pkg/tls/bufLazy_test.go deleted file mode 100644 index 61064596c..000000000 --- a/pkg/tls/bufLazy_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package tls - -import ( - "strings" - "testing" - - "github.com/stretchr/testify/assert" -) - -func eval(t *testing.T, buf *LazyBuffer, data []byte, n int, isLazy bool) { - assert.EqualValues(t, buf.Bytes(), data) - assert.EqualValues(t, buf.Len(), n) - assert.EqualValues(t, buf.IsLazy(), isLazy) -} - -func TestBufLazyEmpty(t *testing.T) { - var buf LazyBuffer - eval(t, &buf, []byte(nil), 0, true) - buf.Write(nil) - eval(t, &buf, []byte{}, 0, false) -} - -func TestBufLazyLazyMode(t *testing.T) { - var buf LazyBuffer - var data []byte = []byte("Hello World!") - buf.Set(data) - eval(t, &buf, data, len(data), true) - - // Next 1 byte - assert.EqualValues(t, buf.Next(1), data[:1]) - eval(t, &buf, data[1:], len(data)-1, true) - - // Next remaining byte - assert.EqualValues(t, buf.Next(len(data)-1), data[1:]) - eval(t, &buf, make([]byte, 0), 0, true) - - // Next 1 more byte - assert.EqualValues(t, buf.Next(1), data[:0]) - eval(t, &buf, make([]byte, 0), 0, true) - - // Reset the data, must be in lazy mode as the previous data is drained - buf.Set(data) - eval(t, &buf, data, len(data), true) - - // Next all byte + 1 - assert.EqualValues(t, buf.Next(len(data)+1), data) - eval(t, &buf, make([]byte, 0), 0, true) - - // Done - buf.Done() - eval(t, &buf, []byte(nil), 0, true) -} - -func TestBufLazyLazyToWriteMode(t *testing.T) { - var buf LazyBuffer - var data []byte = []byte("Hello World!") - buf.Set(data) - eval(t, &buf, data, len(data), true) - - // switch to write - buf.Set(data) - doubleData := append(data, data...) - eval(t, &buf, doubleData, len(doubleData), false) - - // append new data - doubleData = append(doubleData, data...) - buf.Set(data) - eval(t, &buf, doubleData, len(doubleData), false) - - buf.Done() - eval(t, &buf, []byte(nil), 0, true) -} - -func TestBufLazyWriteMode(t *testing.T) { - var buf LazyBuffer - var data []byte = []byte(strings.Repeat("A", defaultSize)) - - buf.Grow(defaultSize) - // fill up the default buffer - buf.Write(data) - eval(t, &buf, data, len(data), false) - - // grow the buffer - buf.Write(data[:1]) - doubleData := append(data, data...) - eval(t, &buf, doubleData[:len(data)+1], len(data)+1, false) - - // fill up the remaining buffer - buf.Write(data[1:]) - eval(t, &buf, doubleData, len(doubleData), false) - - // consume half of the data - assert.EqualValues(t, buf.Next(len(data)), data) - eval(t, &buf, data, len(data), false) - - // consume 1 byte - assert.EqualValues(t, buf.Next(1), data[:1]) - eval(t, &buf, data[1:], len(data)-1, false) - - // grow 1 byte, the data is copied to the beginning - // slide things down - buf.Grow(1) - eval(t, &buf, data[1:], len(data)-1, false) - - // Done - buf.Done() - eval(t, &buf, []byte(nil), 0, true) -} diff --git a/pkg/tls/cache.go b/pkg/tls/cache.go deleted file mode 100644 index fc8f2c084..000000000 --- a/pkg/tls/cache.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2022 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "crypto/x509" - "runtime" - "sync" - "sync/atomic" -) - -type cacheEntry struct { - refs atomic.Int64 - cert *x509.Certificate -} - -// certCache implements an intern table for reference counted x509.Certificates, -// implemented in a similar fashion to BoringSSL's CRYPTO_BUFFER_POOL. This -// allows for a single x509.Certificate to be kept in memory and referenced from -// multiple Conns. Returned references should not be mutated by callers. Certificates -// are still safe to use after they are removed from the cache. -// -// Certificates are returned wrapped in a activeCert struct that should be held by -// the caller. When references to the activeCert are freed, the number of references -// to the certificate in the cache is decremented. Once the number of references -// reaches zero, the entry is evicted from the cache. -// -// The main difference between this implementation and CRYPTO_BUFFER_POOL is that -// CRYPTO_BUFFER_POOL is a more generic structure which supports blobs of data, -// rather than specific structures. Since we only care about x509.Certificates, -// certCache is implemented as a specific cache, rather than a generic one. -// -// See https://boringssl.googlesource.com/boringssl/+/master/include/openssl/pool.h -// and https://boringssl.googlesource.com/boringssl/+/master/crypto/pool/pool.c -// for the BoringSSL reference. -type certCache struct { - sync.Map -} - -var clientCertCache = new(certCache) - -// activeCert is a handle to a certificate held in the cache. Once there are -// no alive activeCerts for a given certificate, the certificate is removed -// from the cache by a finalizer. -type activeCert struct { - cert *x509.Certificate -} - -// active increments the number of references to the entry, wraps the -// certificate in the entry in a activeCert, and sets the finalizer. -// -// Note that there is a race between active and the finalizer set on the -// returned activeCert, triggered if active is called after the ref count is -// decremented such that refs may be > 0 when evict is called. We consider this -// safe, since the caller holding an activeCert for an entry that is no longer -// in the cache is fine, with the only side effect being the memory overhead of -// there being more than one distinct reference to a certificate alive at once. -func (cc *certCache) active(e *cacheEntry) *activeCert { - e.refs.Add(1) - a := &activeCert{e.cert} - runtime.SetFinalizer(a, func(_ *activeCert) { - if e.refs.Add(-1) == 0 { - cc.evict(e) - } - }) - return a -} - -// evict removes a cacheEntry from the cache. -func (cc *certCache) evict(e *cacheEntry) { - cc.Delete(string(e.cert.Raw)) -} - -// newCert returns a x509.Certificate parsed from der. If there is already a copy -// of the certificate in the cache, a reference to the existing certificate will -// be returned. Otherwise, a fresh certificate will be added to the cache, and -// the reference returned. The returned reference should not be mutated. -func (cc *certCache) newCert(der []byte) (*activeCert, error) { - if entry, ok := cc.Load(string(der)); ok { - return cc.active(entry.(*cacheEntry)), nil - } - - cert, err := x509.ParseCertificate(der) - if err != nil { - return nil, err - } - - entry := &cacheEntry{cert: cert} - if entry, loaded := cc.LoadOrStore(string(der), entry); loaded { - return cc.active(entry.(*cacheEntry)), nil - } - return cc.active(entry), nil -} diff --git a/pkg/tls/cache_test.go b/pkg/tls/cache_test.go deleted file mode 100644 index 284673419..000000000 --- a/pkg/tls/cache_test.go +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright 2022 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "encoding/pem" - "fmt" - "runtime" - "testing" - "time" -) - -func TestCertCache(t *testing.T) { - cc := certCache{} - p, _ := pem.Decode([]byte(rsaCertPEM)) - if p == nil { - t.Fatal("Failed to decode certificate") - } - - certA, err := cc.newCert(p.Bytes) - if err != nil { - t.Fatalf("newCert failed: %s", err) - } - certB, err := cc.newCert(p.Bytes) - if err != nil { - t.Fatalf("newCert failed: %s", err) - } - if certA.cert != certB.cert { - t.Fatal("newCert returned a unique reference for a duplicate certificate") - } - - if entry, ok := cc.Load(string(p.Bytes)); !ok { - t.Fatal("cache does not contain expected entry") - } else { - if refs := entry.(*cacheEntry).refs.Load(); refs != 2 { - t.Fatalf("unexpected number of references: got %d, want 2", refs) - } - } - - timeoutRefCheck := func(t *testing.T, key string, count int64) { - t.Helper() - c := time.After(4 * time.Second) - for { - select { - case <-c: - t.Fatal("timed out waiting for expected ref count") - default: - e, ok := cc.Load(key) - if !ok && count != 0 { - t.Fatal("cache does not contain expected key") - } else if count == 0 && !ok { - return - } - - if e.(*cacheEntry).refs.Load() == count { - return - } - } - } - } - - // Keep certA alive until at least now, so that we can - // purposefully nil it and force the finalizer to be - // called. - runtime.KeepAlive(certA) - certA = nil - runtime.GC() - - timeoutRefCheck(t, string(p.Bytes), 1) - - // Keep certB alive until at least now, so that we can - // purposefully nil it and force the finalizer to be - // called. - runtime.KeepAlive(certB) - certB = nil - runtime.GC() - - timeoutRefCheck(t, string(p.Bytes), 0) -} - -func BenchmarkCertCache(b *testing.B) { - p, _ := pem.Decode([]byte(rsaCertPEM)) - if p == nil { - b.Fatal("Failed to decode certificate") - } - - cc := certCache{} - b.ReportAllocs() - b.ResetTimer() - // We expect that calling newCert additional times after - // the initial call should not cause additional allocations. - for extra := 0; extra < 4; extra++ { - b.Run(fmt.Sprint(extra), func(b *testing.B) { - actives := make([]*activeCert, extra+1) - b.ResetTimer() - for i := 0; i < b.N; i++ { - var err error - actives[0], err = cc.newCert(p.Bytes) - if err != nil { - b.Fatal(err) - } - for j := 0; j < extra; j++ { - actives[j+1], err = cc.newCert(p.Bytes) - if err != nil { - b.Fatal(err) - } - } - for j := 0; j < extra+1; j++ { - actives[j] = nil - } - runtime.GC() - } - }) - } -} diff --git a/pkg/tls/cipher_suites.go b/pkg/tls/cipher_suites.go deleted file mode 100644 index 3077e0ab7..000000000 --- a/pkg/tls/cipher_suites.go +++ /dev/null @@ -1,702 +0,0 @@ -// Copyright 2010 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "crypto" - "crypto/aes" - "crypto/cipher" - "crypto/des" - "crypto/hmac" - "crypto/rc4" - "crypto/sha1" - "crypto/sha256" - "fmt" - "hash" - "runtime" - - "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/sys/cpu" - "github.com/panjf2000/gnet/v2/internal/boring" -) - -// CipherSuite is a TLS cipher suite. Note that most functions in this package -// accept and expose cipher suite IDs instead of this type. -type CipherSuite struct { - ID uint16 - Name string - - // Supported versions is the list of TLS protocol versions that can - // negotiate this cipher suite. - SupportedVersions []uint16 - - // Insecure is true if the cipher suite has known security issues - // due to its primitives, design, or implementation. - Insecure bool -} - -var ( - supportedUpToTLS12 = []uint16{VersionTLS10, VersionTLS11, VersionTLS12} - supportedOnlyTLS12 = []uint16{VersionTLS12} - supportedOnlyTLS13 = []uint16{VersionTLS13} -) - -// CipherSuites returns a list of cipher suites currently implemented by this -// package, excluding those with security issues, which are returned by -// InsecureCipherSuites. -// -// The list is sorted by ID. Note that the default cipher suites selected by -// this package might depend on logic that can't be captured by a static list, -// and might not match those returned by this function. -func CipherSuites() []*CipherSuite { - return []*CipherSuite{ - {TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, - {TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, - {TLS_RSA_WITH_AES_128_GCM_SHA256, "TLS_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, - {TLS_RSA_WITH_AES_256_GCM_SHA384, "TLS_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, - - {TLS_AES_128_GCM_SHA256, "TLS_AES_128_GCM_SHA256", supportedOnlyTLS13, false}, - {TLS_AES_256_GCM_SHA384, "TLS_AES_256_GCM_SHA384", supportedOnlyTLS13, false}, - {TLS_CHACHA20_POLY1305_SHA256, "TLS_CHACHA20_POLY1305_SHA256", supportedOnlyTLS13, false}, - - {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, - {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, - {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, - {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, - {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, - {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, - {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, - {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, - {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false}, - {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false}, - } -} - -// InsecureCipherSuites returns a list of cipher suites currently implemented by -// this package and which have security issues. -// -// Most applications should not use the cipher suites in this list, and should -// only use those returned by CipherSuites. -func InsecureCipherSuites() []*CipherSuite { - // This list includes RC4, CBC_SHA256, and 3DES cipher suites. See - // cipherSuitesPreferenceOrder for details. - return []*CipherSuite{ - {TLS_RSA_WITH_RC4_128_SHA, "TLS_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, - {TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true}, - {TLS_RSA_WITH_AES_128_CBC_SHA256, "TLS_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, - {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, - {TLS_ECDHE_RSA_WITH_RC4_128_SHA, "TLS_ECDHE_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, - {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true}, - {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, - {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, - } -} - -// CipherSuiteName returns the standard name for the passed cipher suite ID -// (e.g. "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"), or a fallback representation -// of the ID value if the cipher suite is not implemented by this package. -func CipherSuiteName(id uint16) string { - for _, c := range CipherSuites() { - if c.ID == id { - return c.Name - } - } - for _, c := range InsecureCipherSuites() { - if c.ID == id { - return c.Name - } - } - return fmt.Sprintf("0x%04X", id) -} - -const ( - // suiteECDHE indicates that the cipher suite involves elliptic curve - // Diffie-Hellman. This means that it should only be selected when the - // client indicates that it supports ECC with a curve and point format - // that we're happy with. - suiteECDHE = 1 << iota - // suiteECSign indicates that the cipher suite involves an ECDSA or - // EdDSA signature and therefore may only be selected when the server's - // certificate is ECDSA or EdDSA. If this is not set then the cipher suite - // is RSA based. - suiteECSign - // suiteTLS12 indicates that the cipher suite should only be advertised - // and accepted when using TLS 1.2. - suiteTLS12 - // suiteSHA384 indicates that the cipher suite uses SHA384 as the - // handshake hash. - suiteSHA384 -) - -// A cipherSuite is a TLS 1.0–1.2 cipher suite, and defines the key exchange -// mechanism, as well as the cipher+MAC pair or the AEAD. -type cipherSuite struct { - id uint16 - // the lengths, in bytes, of the key material needed for each component. - keyLen int - macLen int - ivLen int - ka func(version uint16) keyAgreement - // flags is a bitmask of the suite* values, above. - flags int - cipher func(key, iv []byte, isRead bool) interface{} - mac func(key []byte) hash.Hash - aead func(key, fixedNonce []byte) aead -} - -var cipherSuites = []*cipherSuite{ // TODO: replace with a map, since the order doesn't matter. - {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, - {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, - {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM}, - {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadAESGCM}, - {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, - {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, - {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheRSAKA, suiteECDHE | suiteTLS12, cipherAES, macSHA256, nil}, - {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, - {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, cipherAES, macSHA256, nil}, - {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, - {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, - {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, - {TLS_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, rsaKA, suiteTLS12, nil, nil, aeadAESGCM}, - {TLS_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, rsaKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, - {TLS_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, rsaKA, suiteTLS12, cipherAES, macSHA256, nil}, - {TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, - {TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, - {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil}, - {TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil}, - {TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, 0, cipherRC4, macSHA1, nil}, - {TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE, cipherRC4, macSHA1, nil}, - {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherRC4, macSHA1, nil}, -} - -// selectCipherSuite returns the first TLS 1.0–1.2 cipher suite from ids which -// is also in supportedIDs and passes the ok filter. -func selectCipherSuite(ids, supportedIDs []uint16, ok func(*cipherSuite) bool) *cipherSuite { - for _, id := range ids { - candidate := cipherSuiteByID(id) - if candidate == nil || !ok(candidate) { - continue - } - - for _, suppID := range supportedIDs { - if id == suppID { - return candidate - } - } - } - return nil -} - -// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash -// algorithm to be used with HKDF. See RFC 8446, Appendix B.4. -type cipherSuiteTLS13 struct { - id uint16 - keyLen int - aead func(key, fixedNonce []byte) aead - hash crypto.Hash -} - -var cipherSuitesTLS13 = []*cipherSuiteTLS13{ // TODO: replace with a map. - {TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256}, - {TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256}, - {TLS_AES_256_GCM_SHA384, 32, aeadAESGCMTLS13, crypto.SHA384}, -} - -// cipherSuitesPreferenceOrder is the order in which we'll select (on the -// server) or advertise (on the client) TLS 1.0–1.2 cipher suites. -// -// Cipher suites are filtered but not reordered based on the application and -// peer's preferences, meaning we'll never select a suite lower in this list if -// any higher one is available. This makes it more defensible to keep weaker -// cipher suites enabled, especially on the server side where we get the last -// word, since there are no known downgrade attacks on cipher suites selection. -// -// The list is sorted by applying the following priority rules, stopping at the -// first (most important) applicable one: -// -// - Anything else comes before RC4 -// -// RC4 has practically exploitable biases. See https://www.rc4nomore.com. -// -// - Anything else comes before CBC_SHA256 -// -// SHA-256 variants of the CBC ciphersuites don't implement any Lucky13 -// countermeasures. See http://www.isg.rhul.ac.uk/tls/Lucky13.html and -// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. -// -// - Anything else comes before 3DES -// -// 3DES has 64-bit blocks, which makes it fundamentally susceptible to -// birthday attacks. See https://sweet32.info. -// -// - ECDHE comes before anything else -// -// Once we got the broken stuff out of the way, the most important -// property a cipher suite can have is forward secrecy. We don't -// implement FFDHE, so that means ECDHE. -// -// - AEADs come before CBC ciphers -// -// Even with Lucky13 countermeasures, MAC-then-Encrypt CBC cipher suites -// are fundamentally fragile, and suffered from an endless sequence of -// padding oracle attacks. See https://eprint.iacr.org/2015/1129, -// https://www.imperialviolet.org/2014/12/08/poodleagain.html, and -// https://blog.cloudflare.com/yet-another-padding-oracle-in-openssl-cbc-ciphersuites/. -// -// - AES comes before ChaCha20 -// -// When AES hardware is available, AES-128-GCM and AES-256-GCM are faster -// than ChaCha20Poly1305. -// -// When AES hardware is not available, AES-128-GCM is one or more of: much -// slower, way more complex, and less safe (because not constant time) -// than ChaCha20Poly1305. -// -// We use this list if we think both peers have AES hardware, and -// cipherSuitesPreferenceOrderNoAES otherwise. -// -// - AES-128 comes before AES-256 -// -// The only potential advantages of AES-256 are better multi-target -// margins, and hypothetical post-quantum properties. Neither apply to -// TLS, and AES-256 is slower due to its four extra rounds (which don't -// contribute to the advantages above). -// -// - ECDSA comes before RSA -// -// The relative order of ECDSA and RSA cipher suites doesn't matter, -// as they depend on the certificate. Pick one to get a stable order. -var cipherSuitesPreferenceOrder = []uint16{ - // AEADs w/ ECDHE - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, - - // CBC w/ ECDHE - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - - // AEADs w/o ECDHE - TLS_RSA_WITH_AES_128_GCM_SHA256, - TLS_RSA_WITH_AES_256_GCM_SHA384, - - // CBC w/o ECDHE - TLS_RSA_WITH_AES_128_CBC_SHA, - TLS_RSA_WITH_AES_256_CBC_SHA, - - // 3DES - TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, - TLS_RSA_WITH_3DES_EDE_CBC_SHA, - - // CBC_SHA256 - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, - TLS_RSA_WITH_AES_128_CBC_SHA256, - - // RC4 - TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, - TLS_RSA_WITH_RC4_128_SHA, -} - -var cipherSuitesPreferenceOrderNoAES = []uint16{ - // ChaCha20Poly1305 - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, - - // AES-GCM w/ ECDHE - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - - // The rest of cipherSuitesPreferenceOrder. - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - TLS_RSA_WITH_AES_128_GCM_SHA256, - TLS_RSA_WITH_AES_256_GCM_SHA384, - TLS_RSA_WITH_AES_128_CBC_SHA, - TLS_RSA_WITH_AES_256_CBC_SHA, - TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, - TLS_RSA_WITH_3DES_EDE_CBC_SHA, - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, - TLS_RSA_WITH_AES_128_CBC_SHA256, - TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, - TLS_RSA_WITH_RC4_128_SHA, -} - -// disabledCipherSuites are not used unless explicitly listed in -// Config.CipherSuites. They MUST be at the end of cipherSuitesPreferenceOrder. -var disabledCipherSuites = []uint16{ - // CBC_SHA256 - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, - TLS_RSA_WITH_AES_128_CBC_SHA256, - - // RC4 - TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, - TLS_RSA_WITH_RC4_128_SHA, -} - -var ( - defaultCipherSuitesLen = len(cipherSuitesPreferenceOrder) - len(disabledCipherSuites) - defaultCipherSuites = cipherSuitesPreferenceOrder[:defaultCipherSuitesLen] -) - -// defaultCipherSuitesTLS13 is also the preference order, since there are no -// disabled by default TLS 1.3 cipher suites. The same AES vs ChaCha20 logic as -// cipherSuitesPreferenceOrder applies. -var defaultCipherSuitesTLS13 = []uint16{ - TLS_AES_128_GCM_SHA256, - TLS_AES_256_GCM_SHA384, - TLS_CHACHA20_POLY1305_SHA256, -} - -var defaultCipherSuitesTLS13NoAES = []uint16{ - TLS_CHACHA20_POLY1305_SHA256, - TLS_AES_128_GCM_SHA256, - TLS_AES_256_GCM_SHA384, -} - -var ( - hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ - hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL - // Keep in sync with crypto/aes/cipher_s390x.go. - hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR && - (cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM) - - hasAESGCMHardwareSupport = runtime.GOARCH == "amd64" && hasGCMAsmAMD64 || - runtime.GOARCH == "arm64" && hasGCMAsmARM64 || - runtime.GOARCH == "s390x" && hasGCMAsmS390X -) - -var aesgcmCiphers = map[uint16]bool{ - // TLS 1.2 - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: true, - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: true, - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: true, - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: true, - // TLS 1.3 - TLS_AES_128_GCM_SHA256: true, - TLS_AES_256_GCM_SHA384: true, -} - -var nonAESGCMAEADCiphers = map[uint16]bool{ - // TLS 1.2 - TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: true, - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: true, - // TLS 1.3 - TLS_CHACHA20_POLY1305_SHA256: true, -} - -// aesgcmPreferred returns whether the first known cipher in the preference list -// is an AES-GCM cipher, implying the peer has hardware support for it. -func aesgcmPreferred(ciphers []uint16) bool { - for _, cID := range ciphers { - if c := cipherSuiteByID(cID); c != nil { - return aesgcmCiphers[cID] - } - if c := cipherSuiteTLS13ByID(cID); c != nil { - return aesgcmCiphers[cID] - } - } - return false -} - -func cipherRC4(key, iv []byte, isRead bool) any { - cipher, _ := rc4.NewCipher(key) - return cipher -} - -func cipher3DES(key, iv []byte, isRead bool) interface{} { - block, _ := des.NewTripleDESCipher(key) - if isRead { - return cipher.NewCBCDecrypter(block, iv) - } - return cipher.NewCBCEncrypter(block, iv) -} - -func cipherAES(key, iv []byte, isRead bool) interface{} { - block, _ := aes.NewCipher(key) - if isRead { - return cipher.NewCBCDecrypter(block, iv) - } - return cipher.NewCBCEncrypter(block, iv) -} - -// macSHA1 returns a SHA-1 based constant time MAC. -func macSHA1(key []byte) hash.Hash { - h := sha1.New - // The BoringCrypto SHA1 does not have a constant-time - // checksum function, so don't try to use it. - if !boring.Enabled { - h = newConstantTimeHash(h) - } - return hmac.New(h, key) -} - -// macSHA256 returns a SHA-256 based MAC. This is only supported in TLS 1.2 and -// is currently only used in disabled-by-default cipher suites. -func macSHA256(key []byte) hash.Hash { - return hmac.New(sha256.New, key) -} - -type aead interface { - cipher.AEAD - - // explicitNonceLen returns the number of bytes of explicit nonce - // included in each record. This is eight for older AEADs and - // zero for modern ones. - explicitNonceLen() int -} - -const ( - aeadNonceLength = 12 - noncePrefixLength = 4 -) - -// prefixNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to -// each call. -type prefixNonceAEAD struct { - // nonce contains the fixed part of the nonce in the first four bytes. - nonce [aeadNonceLength]byte - aead cipher.AEAD -} - -func (f *prefixNonceAEAD) NonceSize() int { return aeadNonceLength - noncePrefixLength } -func (f *prefixNonceAEAD) Overhead() int { return f.aead.Overhead() } -func (f *prefixNonceAEAD) explicitNonceLen() int { return f.NonceSize() } - -func (f *prefixNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { - copy(f.nonce[4:], nonce) - return f.aead.Seal(out, f.nonce[:], plaintext, additionalData) -} - -func (f *prefixNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { - copy(f.nonce[4:], nonce) - return f.aead.Open(out, f.nonce[:], ciphertext, additionalData) -} - -// xorNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce -// before each call. -type xorNonceAEAD struct { - nonceMask [aeadNonceLength]byte - aead cipher.AEAD -} - -func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number -func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() } -func (f *xorNonceAEAD) explicitNonceLen() int { return 0 } - -func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { - for i, b := range nonce { - f.nonceMask[4+i] ^= b - } - result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData) - for i, b := range nonce { - f.nonceMask[4+i] ^= b - } - - return result -} - -func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { - for i, b := range nonce { - f.nonceMask[4+i] ^= b - } - result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData) - for i, b := range nonce { - f.nonceMask[4+i] ^= b - } - - return result, err -} - -func aeadAESGCM(key, noncePrefix []byte) aead { - if len(noncePrefix) != noncePrefixLength { - panic("tls: internal error: wrong nonce length") - } - aes, err := aes.NewCipher(key) - if err != nil { - panic(err) - } - var aead cipher.AEAD - if boring.Enabled { - aead, err = boring.NewGCMTLS(aes) - } else { - boring.Unreachable() - aead, err = cipher.NewGCM(aes) - } - if err != nil { - panic(err) - } - - ret := &prefixNonceAEAD{aead: aead} - copy(ret.nonce[:], noncePrefix) - return ret -} - -func aeadAESGCMTLS13(key, nonceMask []byte) aead { - if len(nonceMask) != aeadNonceLength { - panic("tls: internal error: wrong nonce length") - } - aes, err := aes.NewCipher(key) - if err != nil { - panic(err) - } - aead, err := cipher.NewGCM(aes) - if err != nil { - panic(err) - } - - ret := &xorNonceAEAD{aead: aead} - copy(ret.nonceMask[:], nonceMask) - return ret -} - -func aeadChaCha20Poly1305(key, nonceMask []byte) aead { - if len(nonceMask) != aeadNonceLength { - panic("tls: internal error: wrong nonce length") - } - aead, err := chacha20poly1305.New(key) - if err != nil { - panic(err) - } - - ret := &xorNonceAEAD{aead: aead} - copy(ret.nonceMask[:], nonceMask) - return ret -} - -type constantTimeHash interface { - hash.Hash - ConstantTimeSum(b []byte) []byte -} - -// cthWrapper wraps any hash.Hash that implements ConstantTimeSum, and replaces -// with that all calls to Sum. It's used to obtain a ConstantTimeSum-based HMAC. -type cthWrapper struct { - h constantTimeHash -} - -func (c *cthWrapper) Size() int { return c.h.Size() } -func (c *cthWrapper) BlockSize() int { return c.h.BlockSize() } -func (c *cthWrapper) Reset() { c.h.Reset() } -func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) } -func (c *cthWrapper) Sum(b []byte) []byte { return c.h.ConstantTimeSum(b) } - -func newConstantTimeHash(h func() hash.Hash) func() hash.Hash { - boring.Unreachable() - return func() hash.Hash { - return &cthWrapper{h().(constantTimeHash)} - } -} - -// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3. -func tls10MAC(h hash.Hash, out, seq, header, data, extra []byte) []byte { - h.Reset() - h.Write(seq) - h.Write(header) - h.Write(data) - res := h.Sum(out) - if extra != nil { - h.Write(extra) - } - return res -} - -func rsaKA(version uint16) keyAgreement { - return rsaKeyAgreement{} -} - -func ecdheECDSAKA(version uint16) keyAgreement { - return &ecdheKeyAgreement{ - isRSA: false, - version: version, - } -} - -func ecdheRSAKA(version uint16) keyAgreement { - return &ecdheKeyAgreement{ - isRSA: true, - version: version, - } -} - -// mutualCipherSuite returns a cipherSuite given a list of supported -// ciphersuites and the id requested by the peer. -func mutualCipherSuite(have []uint16, want uint16) *cipherSuite { - for _, id := range have { - if id == want { - return cipherSuiteByID(id) - } - } - return nil -} - -func cipherSuiteByID(id uint16) *cipherSuite { - for _, cipherSuite := range cipherSuites { - if cipherSuite.id == id { - return cipherSuite - } - } - return nil -} - -func mutualCipherSuiteTLS13(have []uint16, want uint16) *cipherSuiteTLS13 { - for _, id := range have { - if id == want { - return cipherSuiteTLS13ByID(id) - } - } - return nil -} - -func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 { - for _, cipherSuite := range cipherSuitesTLS13 { - if cipherSuite.id == id { - return cipherSuite - } - } - return nil -} - -// A list of cipher suite IDs that are, or have been, implemented by this -// package. -// -// See https://www.iana.org/assignments/tls-parameters/tls-parameters.xml -const ( - // TLS 1.0 - 1.2 cipher suites. - TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 - TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000a - TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f - TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 - TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003c - TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009c - TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009d - TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xc007 - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xc009 - TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xc00a - TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xc011 - TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xc012 - TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xc013 - TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xc014 - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc023 - TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc027 - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02f - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02b - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc030 - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc02c - TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca8 - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca9 - - // TLS 1.3 cipher suites. - TLS_AES_128_GCM_SHA256 uint16 = 0x1301 - TLS_AES_256_GCM_SHA384 uint16 = 0x1302 - TLS_CHACHA20_POLY1305_SHA256 uint16 = 0x1303 - - // TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator - // that the client is doing version fallback. See RFC 7507. - TLS_FALLBACK_SCSV uint16 = 0x5600 - - // Legacy names for the corresponding cipher suites with the correct _SHA256 - // suffix, retained for backward compatibility. - TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 -) diff --git a/pkg/tls/common.go b/pkg/tls/common.go deleted file mode 100644 index 5394d64ac..000000000 --- a/pkg/tls/common.go +++ /dev/null @@ -1,1510 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "bytes" - "container/list" - "context" - "crypto" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/elliptic" - "crypto/rand" - "crypto/rsa" - "crypto/sha512" - "crypto/x509" - "errors" - "fmt" - "io" - "net" - "strings" - "sync" - "time" -) - -const ( - VersionTLS10 = 0x0301 - VersionTLS11 = 0x0302 - VersionTLS12 = 0x0303 - VersionTLS13 = 0x0304 - - // Deprecated: SSLv3 is cryptographically broken, and is no longer - // supported by this package. See golang.org/issue/32716. - VersionSSL30 = 0x0300 -) - -const ( - maxPlaintext = 16384 // maximum plaintext payload length - maxCiphertext = 16384 + 2048 // maximum ciphertext payload length - maxCiphertextTLS13 = 16384 + 256 // maximum ciphertext length in TLS 1.3 - recordHeaderLen = 5 // record header length - maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) - maxUselessRecords = 16 // maximum number of consecutive non-advancing records -) - -// TLS record types. -type recordType uint8 - -const ( - recordTypeChangeCipherSpec recordType = 20 - recordTypeAlert recordType = 21 - recordTypeHandshake recordType = 22 - recordTypeApplicationData recordType = 23 -) - -// TLS handshake message types. -const ( - typeHelloRequest uint8 = 0 - typeClientHello uint8 = 1 - typeServerHello uint8 = 2 - typeNewSessionTicket uint8 = 4 - typeEndOfEarlyData uint8 = 5 - typeEncryptedExtensions uint8 = 8 - typeCertificate uint8 = 11 - typeServerKeyExchange uint8 = 12 - typeCertificateRequest uint8 = 13 - typeServerHelloDone uint8 = 14 - typeCertificateVerify uint8 = 15 - typeClientKeyExchange uint8 = 16 - typeFinished uint8 = 20 - typeCertificateStatus uint8 = 22 - typeKeyUpdate uint8 = 24 - typeNextProtocol uint8 = 67 // Not IANA assigned - typeMessageHash uint8 = 254 // synthetic message -) - -// TLS compression types. -const ( - compressionNone uint8 = 0 -) - -// TLS extension numbers -const ( - extensionServerName uint16 = 0 - extensionStatusRequest uint16 = 5 - extensionSupportedCurves uint16 = 10 // supported_groups in TLS 1.3, see RFC 8446, Section 4.2.7 - extensionSupportedPoints uint16 = 11 - extensionSignatureAlgorithms uint16 = 13 - extensionALPN uint16 = 16 - extensionSCT uint16 = 18 - extensionSessionTicket uint16 = 35 - extensionPreSharedKey uint16 = 41 - extensionEarlyData uint16 = 42 - extensionSupportedVersions uint16 = 43 - extensionCookie uint16 = 44 - extensionPSKModes uint16 = 45 - extensionCertificateAuthorities uint16 = 47 - extensionSignatureAlgorithmsCert uint16 = 50 - extensionKeyShare uint16 = 51 - extensionRenegotiationInfo uint16 = 0xff01 -) - -// TLS signaling cipher suite values -const ( - scsvRenegotiation uint16 = 0x00ff -) - -// CurveID is the type of a TLS identifier for an elliptic curve. See -// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8. -// -// In TLS 1.3, this type is called NamedGroup, but at this time this library -// only supports Elliptic Curve based groups. See RFC 8446, Section 4.2.7. -type CurveID uint16 - -const ( - CurveP256 CurveID = 23 - CurveP384 CurveID = 24 - CurveP521 CurveID = 25 - X25519 CurveID = 29 -) - -// TLS 1.3 Key Share. See RFC 8446, Section 4.2.8. -type keyShare struct { - group CurveID - data []byte -} - -// TLS 1.3 PSK Key Exchange Modes. See RFC 8446, Section 4.2.9. -const ( - pskModePlain uint8 = 0 - pskModeDHE uint8 = 1 -) - -// TLS 1.3 PSK Identity. Can be a Session Ticket, or a reference to a saved -// session. See RFC 8446, Section 4.2.11. -type pskIdentity struct { - label []byte - obfuscatedTicketAge uint32 -} - -// TLS Elliptic Curve Point Formats -// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9 -const ( - pointFormatUncompressed uint8 = 0 -) - -// TLS CertificateStatusType (RFC 3546) -const ( - statusTypeOCSP uint8 = 1 -) - -// Certificate types (for certificateRequestMsg) -const ( - certTypeRSASign = 1 - certTypeECDSASign = 64 // ECDSA or EdDSA keys, see RFC 8422, Section 3. -) - -// Signature algorithms (for internal signaling use). Starting at 225 to avoid overlap with -// TLS 1.2 codepoints (RFC 5246, Appendix A.4.1), with which these have nothing to do. -const ( - signaturePKCS1v15 uint8 = iota + 225 - signatureRSAPSS - signatureECDSA - signatureEd25519 -) - -// directSigning is a standard Hash value that signals that no pre-hashing -// should be performed, and that the input should be signed directly. It is the -// hash function associated with the Ed25519 signature scheme. -var directSigning crypto.Hash = 0 - -// defaultSupportedSignatureAlgorithms contains the signature and hash algorithms that -// the code advertises as supported in a TLS 1.2+ ClientHello and in a TLS 1.2+ -// CertificateRequest. The two fields are merged to match with TLS 1.3. -// Note that in TLS 1.2, the ECDSA algorithms are not constrained to P-256, etc. -var defaultSupportedSignatureAlgorithms = []SignatureScheme{ - PSSWithSHA256, - ECDSAWithP256AndSHA256, - Ed25519, - PSSWithSHA384, - PSSWithSHA512, - PKCS1WithSHA256, - PKCS1WithSHA384, - PKCS1WithSHA512, - ECDSAWithP384AndSHA384, - ECDSAWithP521AndSHA512, - PKCS1WithSHA1, - ECDSAWithSHA1, -} - -// helloRetryRequestRandom is set as the Random value of a ServerHello -// to signal that the message is actually a HelloRetryRequest. -var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3. - 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, - 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, - 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, - 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, -} - -const ( - // downgradeCanaryTLS12 or downgradeCanaryTLS11 is embedded in the server - // random as a downgrade protection if the server would be capable of - // negotiating a higher version. See RFC 8446, Section 4.1.3. - downgradeCanaryTLS12 = "DOWNGRD\x01" - downgradeCanaryTLS11 = "DOWNGRD\x00" -) - -// testingOnlyForceDowngradeCanary is set in tests to force the server side to -// include downgrade canaries even if it's using its highers supported version. -var testingOnlyForceDowngradeCanary bool - -// ConnectionState records basic TLS details about the connection. -type ConnectionState struct { - // Version is the TLS version used by the connection (e.g. VersionTLS12). - Version uint16 - - // HandshakeComplete is true if the handshake has concluded. - HandshakeComplete bool - - // DidResume is true if this connection was successfully resumed from a - // previous session with a session ticket or similar mechanism. - DidResume bool - - // CipherSuite is the cipher suite negotiated for the connection (e.g. - // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_AES_128_GCM_SHA256). - CipherSuite uint16 - - // NegotiatedProtocol is the application protocol negotiated with ALPN. - NegotiatedProtocol string - - // NegotiatedProtocolIsMutual used to indicate a mutual NPN negotiation. - // - // Deprecated: this value is always true. - NegotiatedProtocolIsMutual bool - - // ServerName is the value of the Server Name Indication extension sent by - // the client. It's available both on the server and on the client side. - ServerName string - - // PeerCertificates are the parsed certificates sent by the peer, in the - // order in which they were sent. The first element is the leaf certificate - // that the connection is verified against. - // - // On the client side, it can't be empty. On the server side, it can be - // empty if Config.ClientAuth is not RequireAnyClientCert or - // RequireAndVerifyClientCert. - // - // PeerCertificates and its contents should not be modified. - PeerCertificates []*x509.Certificate - - // VerifiedChains is a list of one or more chains where the first element is - // PeerCertificates[0] and the last element is from Config.RootCAs (on the - // client side) or Config.ClientCAs (on the server side). - // - // On the client side, it's set if Config.InsecureSkipVerify is false. On - // the server side, it's set if Config.ClientAuth is VerifyClientCertIfGiven - // (and the peer provided a certificate) or RequireAndVerifyClientCert. - // - // VerifiedChains and its contents should not be modified. - VerifiedChains [][]*x509.Certificate - - // SignedCertificateTimestamps is a list of SCTs provided by the peer - // through the TLS handshake for the leaf certificate, if any. - SignedCertificateTimestamps [][]byte - - // OCSPResponse is a stapled Online Certificate Status Protocol (OCSP) - // response provided by the peer for the leaf certificate, if any. - OCSPResponse []byte - - // TLSUnique contains the "tls-unique" channel binding value (see RFC 5929, - // Section 3). This value will be nil for TLS 1.3 connections and for all - // resumed connections. - // - // Deprecated: there are conditions in which this value might not be unique - // to a connection. See the Security Considerations sections of RFC 5705 and - // RFC 7627, and https://mitls.org/pages/attacks/3SHAKE#channelbindings. - TLSUnique []byte - - // ekm is a closure exposed via ExportKeyingMaterial. - ekm func(label string, context []byte, length int) ([]byte, error) -} - -// ExportKeyingMaterial returns length bytes of exported key material in a new -// slice as defined in RFC 5705. If context is nil, it is not used as part of -// the seed. If the connection was set to allow renegotiation via -// Config.Renegotiation, this function will return an error. -func (cs *ConnectionState) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) { - return cs.ekm(label, context, length) -} - -// ClientAuthType declares the policy the server will follow for -// TLS Client Authentication. -type ClientAuthType int - -const ( - // NoClientCert indicates that no client certificate should be requested - // during the handshake, and if any certificates are sent they will not - // be verified. - NoClientCert ClientAuthType = iota - // RequestClientCert indicates that a client certificate should be requested - // during the handshake, but does not require that the client send any - // certificates. - RequestClientCert - // RequireAnyClientCert indicates that a client certificate should be requested - // during the handshake, and that at least one certificate is required to be - // sent by the client, but that certificate is not required to be valid. - RequireAnyClientCert - // VerifyClientCertIfGiven indicates that a client certificate should be requested - // during the handshake, but does not require that the client sends a - // certificate. If the client does send a certificate it is required to be - // valid. - VerifyClientCertIfGiven - // RequireAndVerifyClientCert indicates that a client certificate should be requested - // during the handshake, and that at least one valid certificate is required - // to be sent by the client. - RequireAndVerifyClientCert -) - -// requiresClientCert reports whether the ClientAuthType requires a client -// certificate to be provided. -func requiresClientCert(c ClientAuthType) bool { - switch c { - case RequireAnyClientCert, RequireAndVerifyClientCert: - return true - default: - return false - } -} - -// ClientSessionState contains the state needed by clients to resume TLS -// sessions. -type ClientSessionState struct { - sessionTicket []uint8 // Encrypted ticket used for session resumption with server - vers uint16 // TLS version negotiated for the session - cipherSuite uint16 // Ciphersuite negotiated for the session - masterSecret []byte // Full handshake MasterSecret, or TLS 1.3 resumption_master_secret - serverCertificates []*x509.Certificate // Certificate chain presented by the server - verifiedChains [][]*x509.Certificate // Certificate chains we built for verification - receivedAt time.Time // When the session ticket was received from the server - ocspResponse []byte // Stapled OCSP response presented by the server - scts [][]byte // SCTs presented by the server - - // TLS 1.3 fields. - nonce []byte // Ticket nonce sent by the server, to derive PSK - useBy time.Time // Expiration of the ticket lifetime as set by the server - ageAdd uint32 // Random obfuscation factor for sending the ticket age -} - -// ClientSessionCache is a cache of ClientSessionState objects that can be used -// by a client to resume a TLS session with a given server. ClientSessionCache -// implementations should expect to be called concurrently from different -// goroutines. Up to TLS 1.2, only ticket-based resumption is supported, not -// SessionID-based resumption. In TLS 1.3 they were merged into PSK modes, which -// are supported via this interface. -type ClientSessionCache interface { - // Get searches for a ClientSessionState associated with the given key. - // On return, ok is true if one was found. - Get(sessionKey string) (session *ClientSessionState, ok bool) - - // Put adds the ClientSessionState to the cache with the given key. It might - // get called multiple times in a connection if a TLS 1.3 server provides - // more than one session ticket. If called with a nil *ClientSessionState, - // it should remove the cache entry. - Put(sessionKey string, cs *ClientSessionState) -} - -//go:generate stringer -type=SignatureScheme,CurveID,ClientAuthType -output=common_string.go - -// SignatureScheme identifies a signature algorithm supported by TLS. See -// RFC 8446, Section 4.2.3. -type SignatureScheme uint16 - -const ( - // RSASSA-PKCS1-v1_5 algorithms. - PKCS1WithSHA256 SignatureScheme = 0x0401 - PKCS1WithSHA384 SignatureScheme = 0x0501 - PKCS1WithSHA512 SignatureScheme = 0x0601 - - // RSASSA-PSS algorithms with public key OID rsaEncryption. - PSSWithSHA256 SignatureScheme = 0x0804 - PSSWithSHA384 SignatureScheme = 0x0805 - PSSWithSHA512 SignatureScheme = 0x0806 - - // ECDSA algorithms. Only constrained to a specific curve in TLS 1.3. - ECDSAWithP256AndSHA256 SignatureScheme = 0x0403 - ECDSAWithP384AndSHA384 SignatureScheme = 0x0503 - ECDSAWithP521AndSHA512 SignatureScheme = 0x0603 - - // EdDSA algorithms. - Ed25519 SignatureScheme = 0x0807 - - // Legacy signature and hash algorithms for TLS 1.2. - PKCS1WithSHA1 SignatureScheme = 0x0201 - ECDSAWithSHA1 SignatureScheme = 0x0203 -) - -// ClientHelloInfo contains information from a ClientHello message in order to -// guide application logic in the GetCertificate and GetConfigForClient callbacks. -type ClientHelloInfo struct { - // CipherSuites lists the CipherSuites supported by the client (e.g. - // TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256). - CipherSuites []uint16 - - // ServerName indicates the name of the server requested by the client - // in order to support virtual hosting. ServerName is only set if the - // client is using SNI (see RFC 4366, Section 3.1). - ServerName string - - // SupportedCurves lists the elliptic curves supported by the client. - // SupportedCurves is set only if the Supported Elliptic Curves - // Extension is being used (see RFC 4492, Section 5.1.1). - SupportedCurves []CurveID - - // SupportedPoints lists the point formats supported by the client. - // SupportedPoints is set only if the Supported Point Formats Extension - // is being used (see RFC 4492, Section 5.1.2). - SupportedPoints []uint8 - - // SignatureSchemes lists the signature and hash schemes that the client - // is willing to verify. SignatureSchemes is set only if the Signature - // Algorithms Extension is being used (see RFC 5246, Section 7.4.1.4.1). - SignatureSchemes []SignatureScheme - - // SupportedProtos lists the application protocols supported by the client. - // SupportedProtos is set only if the Application-Layer Protocol - // Negotiation Extension is being used (see RFC 7301, Section 3.1). - // - // Servers can select a protocol by setting Config.NextProtos in a - // GetConfigForClient return value. - SupportedProtos []string - - // SupportedVersions lists the TLS versions supported by the client. - // For TLS versions less than 1.3, this is extrapolated from the max - // version advertised by the client, so values other than the greatest - // might be rejected if used. - SupportedVersions []uint16 - - // Conn is the underlying net.Conn for the connection. Do not read - // from, or write to, this connection; that will cause the TLS - // connection to fail. - Conn net.Conn - - // config is embedded by the GetCertificate or GetConfigForClient caller, - // for use with SupportsCertificate. - config *Config - - // ctx is the context of the handshake that is in progress. - ctx context.Context -} - -// Context returns the context of the handshake that is in progress. -// This context is a child of the context passed to HandshakeContext, -// if any, and is canceled when the handshake concludes. -func (c *ClientHelloInfo) Context() context.Context { - return c.ctx -} - -// CertificateRequestInfo contains information from a server's -// CertificateRequest message, which is used to demand a certificate and proof -// of control from a client. -type CertificateRequestInfo struct { - // AcceptableCAs contains zero or more, DER-encoded, X.501 - // Distinguished Names. These are the names of root or intermediate CAs - // that the server wishes the returned certificate to be signed by. An - // empty slice indicates that the server has no preference. - AcceptableCAs [][]byte - - // SignatureSchemes lists the signature schemes that the server is - // willing to verify. - SignatureSchemes []SignatureScheme - - // Version is the TLS version that was negotiated for this connection. - Version uint16 - - // ctx is the context of the handshake that is in progress. - ctx context.Context -} - -// Context returns the context of the handshake that is in progress. -// This context is a child of the context passed to HandshakeContext, -// if any, and is canceled when the handshake concludes. -func (c *CertificateRequestInfo) Context() context.Context { - return c.ctx -} - -// RenegotiationSupport enumerates the different levels of support for TLS -// renegotiation. TLS renegotiation is the act of performing subsequent -// handshakes on a connection after the first. This significantly complicates -// the state machine and has been the source of numerous, subtle security -// issues. Initiating a renegotiation is not supported, but support for -// accepting renegotiation requests may be enabled. -// -// Even when enabled, the server may not change its identity between handshakes -// (i.e. the leaf certificate must be the same). Additionally, concurrent -// handshake and application data flow is not permitted so renegotiation can -// only be used with protocols that synchronise with the renegotiation, such as -// HTTPS. -// -// Renegotiation is not defined in TLS 1.3. -type RenegotiationSupport int - -const ( - // RenegotiateNever disables renegotiation. - RenegotiateNever RenegotiationSupport = iota - - // RenegotiateOnceAsClient allows a remote server to request - // renegotiation once per connection. - RenegotiateOnceAsClient - - // RenegotiateFreelyAsClient allows a remote server to repeatedly - // request renegotiation. - RenegotiateFreelyAsClient -) - -// A Config structure is used to configure a TLS client or server. -// After one has been passed to a TLS function it must not be -// modified. A Config may be reused; the tls package will also not -// modify it. -type Config struct { - // Rand provides the source of entropy for nonces and RSA blinding. - // If Rand is nil, TLS uses the cryptographic random reader in package - // crypto/rand. - // The Reader must be safe for use by multiple goroutines. - Rand io.Reader - - // Time returns the current time as the number of seconds since the epoch. - // If Time is nil, TLS uses time.Now. - Time func() time.Time - - // Certificates contains one or more certificate chains to present to the - // other side of the connection. The first certificate compatible with the - // peer's requirements is selected automatically. - // - // Server configurations must set one of Certificates, GetCertificate or - // GetConfigForClient. Clients doing client-authentication may set either - // Certificates or GetClientCertificate. - // - // Note: if there are multiple Certificates, and they don't have the - // optional field Leaf set, certificate selection will incur a significant - // per-handshake performance cost. - Certificates []Certificate - - // NameToCertificate maps from a certificate name to an element of - // Certificates. Note that a certificate name can be of the form - // '*.example.com' and so doesn't have to be a domain name as such. - // - // Deprecated: NameToCertificate only allows associating a single - // certificate with a given name. Leave this field nil to let the library - // select the first compatible chain from Certificates. - NameToCertificate map[string]*Certificate - - // GetCertificate returns a Certificate based on the given - // ClientHelloInfo. It will only be called if the client supplies SNI - // information or if Certificates is empty. - // - // If GetCertificate is nil or returns nil, then the certificate is - // retrieved from NameToCertificate. If NameToCertificate is nil, the - // best element of Certificates will be used. - // - // Once a Certificate is returned it should not be modified. - GetCertificate func(*ClientHelloInfo) (*Certificate, error) - - // GetClientCertificate, if not nil, is called when a server requests a - // certificate from a client. If set, the contents of Certificates will - // be ignored. - // - // If GetClientCertificate returns an error, the handshake will be - // aborted and that error will be returned. Otherwise - // GetClientCertificate must return a non-nil Certificate. If - // Certificate.Certificate is empty then no certificate will be sent to - // the server. If this is unacceptable to the server then it may abort - // the handshake. - // - // GetClientCertificate may be called multiple times for the same - // connection if renegotiation occurs or if TLS 1.3 is in use. - // - // Once a Certificate is returned it should not be modified. - GetClientCertificate func(*CertificateRequestInfo) (*Certificate, error) - - // GetConfigForClient, if not nil, is called after a ClientHello is - // received from a client. It may return a non-nil Config in order to - // change the Config that will be used to handle this connection. If - // the returned Config is nil, the original Config will be used. The - // Config returned by this callback may not be subsequently modified. - // - // If GetConfigForClient is nil, the Config passed to Server() will be - // used for all connections. - // - // If SessionTicketKey was explicitly set on the returned Config, or if - // SetSessionTicketKeys was called on the returned Config, those keys will - // be used. Otherwise, the original Config keys will be used (and possibly - // rotated if they are automatically managed). - GetConfigForClient func(*ClientHelloInfo) (*Config, error) - - // VerifyPeerCertificate, if not nil, is called after normal - // certificate verification by either a TLS client or server. It - // receives the raw ASN.1 certificates provided by the peer and also - // any verified chains that normal processing found. If it returns a - // non-nil error, the handshake is aborted and that error results. - // - // If normal verification fails then the handshake will abort before - // considering this callback. If normal verification is disabled by - // setting InsecureSkipVerify, or (for a server) when ClientAuth is - // RequestClientCert or RequireAnyClientCert, then this callback will - // be considered but the verifiedChains argument will always be nil. - // - // verifiedChains and its contents should not be modified. - VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error - - // VerifyConnection, if not nil, is called after normal certificate - // verification and after VerifyPeerCertificate by either a TLS client - // or server. If it returns a non-nil error, the handshake is aborted - // and that error results. - // - // If normal verification fails then the handshake will abort before - // considering this callback. This callback will run for all connections - // regardless of InsecureSkipVerify or ClientAuth settings. - VerifyConnection func(ConnectionState) error - - // RootCAs defines the set of root certificate authorities - // that clients use when verifying server certificates. - // If RootCAs is nil, TLS uses the host's root CA set. - RootCAs *x509.CertPool - - // NextProtos is a list of supported application level protocols, in - // order of preference. If both peers support ALPN, the selected - // protocol will be one from this list, and the connection will fail - // if there is no mutually supported protocol. If NextProtos is empty - // or the peer doesn't support ALPN, the connection will succeed and - // ConnectionState.NegotiatedProtocol will be empty. - NextProtos []string - - // ServerName is used to verify the hostname on the returned - // certificates unless InsecureSkipVerify is given. It is also included - // in the client's handshake to support virtual hosting unless it is - // an IP address. - ServerName string - - // ClientAuth determines the server's policy for - // TLS Client Authentication. The default is NoClientCert. - ClientAuth ClientAuthType - - // ClientCAs defines the set of root certificate authorities - // that servers use if required to verify a client certificate - // by the policy in ClientAuth. - ClientCAs *x509.CertPool - - // InsecureSkipVerify controls whether a client verifies the server's - // certificate chain and host name. If InsecureSkipVerify is true, crypto/tls - // accepts any certificate presented by the server and any host name in that - // certificate. In this mode, TLS is susceptible to machine-in-the-middle - // attacks unless custom verification is used. This should be used only for - // testing or in combination with VerifyConnection or VerifyPeerCertificate. - InsecureSkipVerify bool - - // CipherSuites is a list of enabled TLS 1.0–1.2 cipher suites. The order of - // the list is ignored. Note that TLS 1.3 ciphersuites are not configurable. - // - // If CipherSuites is nil, a safe default list is used. The default cipher - // suites might change over time. - CipherSuites []uint16 - - // PreferServerCipherSuites is a legacy field and has no effect. - // - // It used to control whether the server would follow the client's or the - // server's preference. Servers now select the best mutually supported - // cipher suite based on logic that takes into account inferred client - // hardware, server hardware, and security. - // - // Deprecated: PreferServerCipherSuites is ignored. - PreferServerCipherSuites bool - - // SessionTicketsDisabled may be set to true to disable session ticket and - // PSK (resumption) support. Note that on clients, session ticket support is - // also disabled if ClientSessionCache is nil. - SessionTicketsDisabled bool - - // SessionTicketKey is used by TLS servers to provide session resumption. - // See RFC 5077 and the PSK mode of RFC 8446. If zero, it will be filled - // with random data before the first server handshake. - // - // Deprecated: if this field is left at zero, session ticket keys will be - // automatically rotated every day and dropped after seven days. For - // customizing the rotation schedule or synchronizing servers that are - // terminating connections for the same host, use SetSessionTicketKeys. - SessionTicketKey [32]byte - - // ClientSessionCache is a cache of ClientSessionState entries for TLS - // session resumption. It is only used by clients. - ClientSessionCache ClientSessionCache - - // MinVersion contains the minimum TLS version that is acceptable. - // - // By default, TLS 1.2 is currently used as the minimum when acting as a - // client, and TLS 1.0 when acting as a server. TLS 1.0 is the minimum - // supported by this package, both as a client and as a server. - // - // The client-side default can temporarily be reverted to TLS 1.0 by - // including the value "x509sha1=1" in the GODEBUG environment variable. - // Note that this option will be removed in Go 1.19 (but it will still be - // possible to set this field to VersionTLS10 explicitly). - MinVersion uint16 - - // MaxVersion contains the maximum TLS version that is acceptable. - // - // By default, the maximum version supported by this package is used, - // which is currently TLS 1.3. - MaxVersion uint16 - - // CurvePreferences contains the elliptic curves that will be used in - // an ECDHE handshake, in preference order. If empty, the default will - // be used. The client will use the first preference as the type for - // its key share in TLS 1.3. This may change in the future. - CurvePreferences []CurveID - - // DynamicRecordSizingDisabled disables adaptive sizing of TLS records. - // When true, the largest possible TLS record size is always used. When - // false, the size of TLS records may be adjusted in an attempt to - // improve latency. - DynamicRecordSizingDisabled bool - - // Renegotiation controls what types of renegotiation are supported. - // The default, none, is correct for the vast majority of applications. - Renegotiation RenegotiationSupport - - // KeyLogWriter optionally specifies a destination for TLS master secrets - // in NSS key log format that can be used to allow external programs - // such as Wireshark to decrypt TLS connections. - // See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format. - // Use of KeyLogWriter compromises security and should only be - // used for debugging. - KeyLogWriter io.Writer - - // mutex protects sessionTicketKeys and autoSessionTicketKeys. - mutex sync.RWMutex - // sessionTicketKeys contains zero or more ticket keys. If set, it means - // the keys were set with SessionTicketKey or SetSessionTicketKeys. The - // first key is used for new tickets and any subsequent keys can be used to - // decrypt old tickets. The slice contents are not protected by the mutex - // and are immutable. - sessionTicketKeys []ticketKey - // autoSessionTicketKeys is like sessionTicketKeys but is owned by the - // auto-rotation logic. See Config.ticketKeys. - autoSessionTicketKeys []ticketKey -} - -const ( - // ticketKeyNameLen is the number of bytes of identifier that is prepended to - // an encrypted session ticket in order to identify the key used to encrypt it. - ticketKeyNameLen = 16 - - // ticketKeyLifetime is how long a ticket key remains valid and can be used to - // resume a client connection. - ticketKeyLifetime = 7 * 24 * time.Hour // 7 days - - // ticketKeyRotation is how often the server should rotate the session ticket key - // that is used for new tickets. - ticketKeyRotation = 24 * time.Hour -) - -// ticketKey is the internal representation of a session ticket key. -type ticketKey struct { - // keyName is an opaque byte string that serves to identify the session - // ticket key. It's exposed as plaintext in every session ticket. - keyName [ticketKeyNameLen]byte - aesKey [16]byte - hmacKey [16]byte - // created is the time at which this ticket key was created. See Config.ticketKeys. - created time.Time -} - -// ticketKeyFromBytes converts from the external representation of a session -// ticket key to a ticketKey. Externally, session ticket keys are 32 random -// bytes and this function expands that into sufficient name and key material. -func (c *Config) ticketKeyFromBytes(b [32]byte) (key ticketKey) { - hashed := sha512.Sum512(b[:]) - copy(key.keyName[:], hashed[:ticketKeyNameLen]) - copy(key.aesKey[:], hashed[ticketKeyNameLen:ticketKeyNameLen+16]) - copy(key.hmacKey[:], hashed[ticketKeyNameLen+16:ticketKeyNameLen+32]) - key.created = c.time() - return key -} - -// maxSessionTicketLifetime is the maximum allowed lifetime of a TLS 1.3 session -// ticket, and the lifetime we set for tickets we send. -const maxSessionTicketLifetime = 7 * 24 * time.Hour - -// Clone returns a shallow clone of c or nil if c is nil. It is safe to clone a Config that is -// being used concurrently by a TLS client or server. -func (c *Config) Clone() *Config { - if c == nil { - return nil - } - c.mutex.RLock() - defer c.mutex.RUnlock() - return &Config{ - Rand: c.Rand, - Time: c.Time, - Certificates: c.Certificates, - NameToCertificate: c.NameToCertificate, - GetCertificate: c.GetCertificate, - GetClientCertificate: c.GetClientCertificate, - GetConfigForClient: c.GetConfigForClient, - VerifyPeerCertificate: c.VerifyPeerCertificate, - VerifyConnection: c.VerifyConnection, - RootCAs: c.RootCAs, - NextProtos: c.NextProtos, - ServerName: c.ServerName, - ClientAuth: c.ClientAuth, - ClientCAs: c.ClientCAs, - InsecureSkipVerify: c.InsecureSkipVerify, - CipherSuites: c.CipherSuites, - PreferServerCipherSuites: c.PreferServerCipherSuites, - SessionTicketsDisabled: c.SessionTicketsDisabled, - SessionTicketKey: c.SessionTicketKey, - ClientSessionCache: c.ClientSessionCache, - MinVersion: c.MinVersion, - MaxVersion: c.MaxVersion, - CurvePreferences: c.CurvePreferences, - DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, - Renegotiation: c.Renegotiation, - KeyLogWriter: c.KeyLogWriter, - sessionTicketKeys: c.sessionTicketKeys, - autoSessionTicketKeys: c.autoSessionTicketKeys, - } -} - -// deprecatedSessionTicketKey is set as the prefix of SessionTicketKey if it was -// randomized for backwards compatibility but is not in use. -var deprecatedSessionTicketKey = []byte("DEPRECATED") - -// initLegacySessionTicketKeyRLocked ensures the legacy SessionTicketKey field is -// randomized if empty, and that sessionTicketKeys is populated from it otherwise. -func (c *Config) initLegacySessionTicketKeyRLocked() { - // Don't write if SessionTicketKey is already defined as our deprecated string, - // or if it is defined by the user but sessionTicketKeys is already set. - if c.SessionTicketKey != [32]byte{} && - (bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) || len(c.sessionTicketKeys) > 0) { - return - } - - // We need to write some data, so get an exclusive lock and re-check any conditions. - c.mutex.RUnlock() - defer c.mutex.RLock() - c.mutex.Lock() - defer c.mutex.Unlock() - if c.SessionTicketKey == [32]byte{} { - if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil { - panic(fmt.Sprintf("tls: unable to generate random session ticket key: %v", err)) - } - // Write the deprecated prefix at the beginning so we know we created - // it. This key with the DEPRECATED prefix isn't used as an actual - // session ticket key, and is only randomized in case the application - // reuses it for some reason. - copy(c.SessionTicketKey[:], deprecatedSessionTicketKey) - } else if !bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) && len(c.sessionTicketKeys) == 0 { - c.sessionTicketKeys = []ticketKey{c.ticketKeyFromBytes(c.SessionTicketKey)} - } - -} - -// ticketKeys returns the ticketKeys for this connection. -// If configForClient has explicitly set keys, those will -// be returned. Otherwise, the keys on c will be used and -// may be rotated if auto-managed. -// During rotation, any expired session ticket keys are deleted from -// c.sessionTicketKeys. If the session ticket key that is currently -// encrypting tickets (ie. the first ticketKey in c.sessionTicketKeys) -// is not fresh, then a new session ticket key will be -// created and prepended to c.sessionTicketKeys. -func (c *Config) ticketKeys(configForClient *Config) []ticketKey { - // If the ConfigForClient callback returned a Config with explicitly set - // keys, use those, otherwise just use the original Config. - if configForClient != nil { - configForClient.mutex.RLock() - if configForClient.SessionTicketsDisabled { - return nil - } - configForClient.initLegacySessionTicketKeyRLocked() - if len(configForClient.sessionTicketKeys) != 0 { - ret := configForClient.sessionTicketKeys - configForClient.mutex.RUnlock() - return ret - } - configForClient.mutex.RUnlock() - } - - c.mutex.RLock() - defer c.mutex.RUnlock() - if c.SessionTicketsDisabled { - return nil - } - c.initLegacySessionTicketKeyRLocked() - if len(c.sessionTicketKeys) != 0 { - return c.sessionTicketKeys - } - // Fast path for the common case where the key is fresh enough. - if len(c.autoSessionTicketKeys) > 0 && c.time().Sub(c.autoSessionTicketKeys[0].created) < ticketKeyRotation { - return c.autoSessionTicketKeys - } - - // autoSessionTicketKeys are managed by auto-rotation. - c.mutex.RUnlock() - defer c.mutex.RLock() - c.mutex.Lock() - defer c.mutex.Unlock() - // Re-check the condition in case it changed since obtaining the new lock. - if len(c.autoSessionTicketKeys) == 0 || c.time().Sub(c.autoSessionTicketKeys[0].created) >= ticketKeyRotation { - var newKey [32]byte - if _, err := io.ReadFull(c.rand(), newKey[:]); err != nil { - panic(fmt.Sprintf("unable to generate random session ticket key: %v", err)) - } - valid := make([]ticketKey, 0, len(c.autoSessionTicketKeys)+1) - valid = append(valid, c.ticketKeyFromBytes(newKey)) - for _, k := range c.autoSessionTicketKeys { - // While rotating the current key, also remove any expired ones. - if c.time().Sub(k.created) < ticketKeyLifetime { - valid = append(valid, k) - } - } - c.autoSessionTicketKeys = valid - } - return c.autoSessionTicketKeys -} - -// SetSessionTicketKeys updates the session ticket keys for a server. -// -// The first key will be used when creating new tickets, while all keys can be -// used for decrypting tickets. It is safe to call this function while the -// server is running in order to rotate the session ticket keys. The function -// will panic if keys is empty. -// -// Calling this function will turn off automatic session ticket key rotation. -// -// If multiple servers are terminating connections for the same host they should -// all have the same session ticket keys. If the session ticket keys leaks, -// previously recorded and future TLS connections using those keys might be -// compromised. -func (c *Config) SetSessionTicketKeys(keys [][32]byte) { - if len(keys) == 0 { - panic("tls: keys must have at least one key") - } - - newKeys := make([]ticketKey, len(keys)) - for i, bytes := range keys { - newKeys[i] = c.ticketKeyFromBytes(bytes) - } - - c.mutex.Lock() - c.sessionTicketKeys = newKeys - c.mutex.Unlock() -} - -func (c *Config) rand() io.Reader { - r := c.Rand - if r == nil { - return rand.Reader - } - return r -} - -func (c *Config) time() time.Time { - t := c.Time - if t == nil { - t = time.Now - } - return t() -} - -func (c *Config) cipherSuites() []uint16 { - if needFIPS() { - return fipsCipherSuites(c) - } - if c.CipherSuites != nil { - return c.CipherSuites - } - return defaultCipherSuites -} - -var supportedVersions = []uint16{ - VersionTLS13, - VersionTLS12, - VersionTLS11, - VersionTLS10, -} - -// roleClient and roleServer are meant to call supportedVersions and parents -// with more readability at the callsite. -const roleClient = true -const roleServer = false - -func (c *Config) supportedVersions(isClient bool) []uint16 { - versions := make([]uint16, 0, len(supportedVersions)) - for _, v := range supportedVersions { - if needFIPS() && (v < fipsMinVersion(c) || v > fipsMaxVersion(c)) { - continue - } - if (c == nil || c.MinVersion == 0) && - isClient && v < VersionTLS12 { - continue - } - if c != nil && c.MinVersion != 0 && v < c.MinVersion { - continue - } - if c != nil && c.MaxVersion != 0 && v > c.MaxVersion { - continue - } - versions = append(versions, v) - } - return versions -} - -func (c *Config) maxSupportedVersion(isClient bool) uint16 { - supportedVersions := c.supportedVersions(isClient) - if len(supportedVersions) == 0 { - return 0 - } - return supportedVersions[0] -} - -// supportedVersionsFromMax returns a list of supported versions derived from a -// legacy maximum version value. Note that only versions supported by this -// library are returned. Any newer peer will use supportedVersions anyway. -func supportedVersionsFromMax(maxVersion uint16) []uint16 { - versions := make([]uint16, 0, len(supportedVersions)) - for _, v := range supportedVersions { - if v > maxVersion { - continue - } - versions = append(versions, v) - } - return versions -} - -var defaultCurvePreferences = []CurveID{X25519, CurveP256, CurveP384, CurveP521} - -func (c *Config) curvePreferences() []CurveID { - if needFIPS() { - return fipsCurvePreferences(c) - } - if c == nil || len(c.CurvePreferences) == 0 { - return defaultCurvePreferences - } - return c.CurvePreferences -} - -func (c *Config) supportsCurve(curve CurveID) bool { - for _, cc := range c.curvePreferences() { - if cc == curve { - return true - } - } - return false -} - -// mutualVersion returns the protocol version to use given the advertised -// versions of the peer. Priority is given to the peer preference order. -func (c *Config) mutualVersion(isClient bool, peerVersions []uint16) (uint16, bool) { - supportedVersions := c.supportedVersions(isClient) - for _, peerVersion := range peerVersions { - for _, v := range supportedVersions { - if v == peerVersion { - return v, true - } - } - } - return 0, false -} - -var errNoCertificates = errors.New("tls: no certificates configured") - -// getCertificate returns the best certificate for the given ClientHelloInfo, -// defaulting to the first element of c.Certificates. -func (c *Config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) { - if c.GetCertificate != nil && - (len(c.Certificates) == 0 || len(clientHello.ServerName) > 0) { - cert, err := c.GetCertificate(clientHello) - if cert != nil || err != nil { - return cert, err - } - } - - if len(c.Certificates) == 0 { - return nil, errNoCertificates - } - - if len(c.Certificates) == 1 { - // There's only one choice, so no point doing any work. - return &c.Certificates[0], nil - } - - if c.NameToCertificate != nil { - name := strings.ToLower(clientHello.ServerName) - if cert, ok := c.NameToCertificate[name]; ok { - return cert, nil - } - if len(name) > 0 { - labels := strings.Split(name, ".") - labels[0] = "*" - wildcardName := strings.Join(labels, ".") - if cert, ok := c.NameToCertificate[wildcardName]; ok { - return cert, nil - } - } - } - - for _, cert := range c.Certificates { - if err := clientHello.SupportsCertificate(&cert); err == nil { - return &cert, nil - } - } - - // If nothing matches, return the first certificate. - return &c.Certificates[0], nil -} - -// SupportsCertificate returns nil if the provided certificate is supported by -// the client that sent the ClientHello. Otherwise, it returns an error -// describing the reason for the incompatibility. -// -// If this ClientHelloInfo was passed to a GetConfigForClient or GetCertificate -// callback, this method will take into account the associated Config. Note that -// if GetConfigForClient returns a different Config, the change can't be -// accounted for by this method. -// -// This function will call x509.ParseCertificate unless c.Leaf is set, which can -// incur a significant performance cost. -func (chi *ClientHelloInfo) SupportsCertificate(c *Certificate) error { - // Note we don't currently support certificate_authorities nor - // signature_algorithms_cert, and don't check the algorithms of the - // signatures on the chain (which anyway are a SHOULD, see RFC 8446, - // Section 4.4.2.2). - - config := chi.config - if config == nil { - config = &Config{} - } - vers, ok := config.mutualVersion(roleServer, chi.SupportedVersions) - if !ok { - return errors.New("no mutually supported protocol versions") - } - - // If the client specified the name they are trying to connect to, the - // certificate needs to be valid for it. - if chi.ServerName != "" { - x509Cert, err := c.leaf() - if err != nil { - return fmt.Errorf("failed to parse certificate: %w", err) - } - if err := x509Cert.VerifyHostname(chi.ServerName); err != nil { - return fmt.Errorf("certificate is not valid for requested server name: %w", err) - } - } - - // supportsRSAFallback returns nil if the certificate and connection support - // the static RSA key exchange, and unsupported otherwise. The logic for - // supporting static RSA is completely disjoint from the logic for - // supporting signed key exchanges, so we just check it as a fallback. - supportsRSAFallback := func(unsupported error) error { - // TLS 1.3 dropped support for the static RSA key exchange. - if vers == VersionTLS13 { - return unsupported - } - // The static RSA key exchange works by decrypting a challenge with the - // RSA private key, not by signing, so check the PrivateKey implements - // crypto.Decrypter, like *rsa.PrivateKey does. - if priv, ok := c.PrivateKey.(crypto.Decrypter); ok { - if _, ok := priv.Public().(*rsa.PublicKey); !ok { - return unsupported - } - } else { - return unsupported - } - // Finally, there needs to be a mutual cipher suite that uses the static - // RSA key exchange instead of ECDHE. - rsaCipherSuite := selectCipherSuite(chi.CipherSuites, config.cipherSuites(), func(c *cipherSuite) bool { - if c.flags&suiteECDHE != 0 { - return false - } - if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { - return false - } - return true - }) - if rsaCipherSuite == nil { - return unsupported - } - return nil - } - - // If the client sent the signature_algorithms extension, ensure it supports - // schemes we can use with this certificate and TLS version. - if len(chi.SignatureSchemes) > 0 { - if _, err := selectSignatureScheme(vers, c, chi.SignatureSchemes); err != nil { - return supportsRSAFallback(err) - } - } - - // In TLS 1.3 we are done because supported_groups is only relevant to the - // ECDHE computation, point format negotiation is removed, cipher suites are - // only relevant to the AEAD choice, and static RSA does not exist. - if vers == VersionTLS13 { - return nil - } - - // The only signed key exchange we support is ECDHE. - if !supportsECDHE(config, chi.SupportedCurves, chi.SupportedPoints) { - return supportsRSAFallback(errors.New("client doesn't support ECDHE, can only use legacy RSA key exchange")) - } - - var ecdsaCipherSuite bool - if priv, ok := c.PrivateKey.(crypto.Signer); ok { - switch pub := priv.Public().(type) { - case *ecdsa.PublicKey: - var curve CurveID - switch pub.Curve { - case elliptic.P256(): - curve = CurveP256 - case elliptic.P384(): - curve = CurveP384 - case elliptic.P521(): - curve = CurveP521 - default: - return supportsRSAFallback(unsupportedCertificateError(c)) - } - var curveOk bool - for _, c := range chi.SupportedCurves { - if c == curve && config.supportsCurve(c) { - curveOk = true - break - } - } - if !curveOk { - return errors.New("client doesn't support certificate curve") - } - ecdsaCipherSuite = true - case ed25519.PublicKey: - if vers < VersionTLS12 || len(chi.SignatureSchemes) == 0 { - return errors.New("connection doesn't support Ed25519") - } - ecdsaCipherSuite = true - case *rsa.PublicKey: - default: - return supportsRSAFallback(unsupportedCertificateError(c)) - } - } else { - return supportsRSAFallback(unsupportedCertificateError(c)) - } - - // Make sure that there is a mutually supported cipher suite that works with - // this certificate. Cipher suite selection will then apply the logic in - // reverse to pick it. See also serverHandshakeState.cipherSuiteOk. - cipherSuite := selectCipherSuite(chi.CipherSuites, config.cipherSuites(), func(c *cipherSuite) bool { - if c.flags&suiteECDHE == 0 { - return false - } - if c.flags&suiteECSign != 0 { - if !ecdsaCipherSuite { - return false - } - } else { - if ecdsaCipherSuite { - return false - } - } - if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { - return false - } - return true - }) - if cipherSuite == nil { - return supportsRSAFallback(errors.New("client doesn't support any cipher suites compatible with the certificate")) - } - - return nil -} - -// SupportsCertificate returns nil if the provided certificate is supported by -// the server that sent the CertificateRequest. Otherwise, it returns an error -// describing the reason for the incompatibility. -func (cri *CertificateRequestInfo) SupportsCertificate(c *Certificate) error { - if _, err := selectSignatureScheme(cri.Version, c, cri.SignatureSchemes); err != nil { - return err - } - - if len(cri.AcceptableCAs) == 0 { - return nil - } - - for j, cert := range c.Certificate { - x509Cert := c.Leaf - // Parse the certificate if this isn't the leaf node, or if - // chain.Leaf was nil. - if j != 0 || x509Cert == nil { - var err error - if x509Cert, err = x509.ParseCertificate(cert); err != nil { - return fmt.Errorf("failed to parse certificate #%d in the chain: %w", j, err) - } - } - - for _, ca := range cri.AcceptableCAs { - if bytes.Equal(x509Cert.RawIssuer, ca) { - return nil - } - } - } - return errors.New("chain is not signed by an acceptable CA") -} - -// BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate -// from the CommonName and SubjectAlternateName fields of each of the leaf -// certificates. -// -// Deprecated: NameToCertificate only allows associating a single certificate -// with a given name. Leave that field nil to let the library select the first -// compatible chain from Certificates. -func (c *Config) BuildNameToCertificate() { - c.NameToCertificate = make(map[string]*Certificate) - for i := range c.Certificates { - cert := &c.Certificates[i] - x509Cert, err := cert.leaf() - if err != nil { - continue - } - // If SANs are *not* present, some clients will consider the certificate - // valid for the name in the Common Name. - if x509Cert.Subject.CommonName != "" && len(x509Cert.DNSNames) == 0 { - c.NameToCertificate[x509Cert.Subject.CommonName] = cert - } - for _, san := range x509Cert.DNSNames { - c.NameToCertificate[san] = cert - } - } -} - -const ( - keyLogLabelTLS12 = "CLIENT_RANDOM" - keyLogLabelClientHandshake = "CLIENT_HANDSHAKE_TRAFFIC_SECRET" - keyLogLabelServerHandshake = "SERVER_HANDSHAKE_TRAFFIC_SECRET" - keyLogLabelClientTraffic = "CLIENT_TRAFFIC_SECRET_0" - keyLogLabelServerTraffic = "SERVER_TRAFFIC_SECRET_0" -) - -func (c *Config) writeKeyLog(label string, clientRandom, secret []byte) error { - if c.KeyLogWriter == nil { - return nil - } - - logLine := fmt.Appendf(nil, "%s %x %x\n", label, clientRandom, secret) - - writerMutex.Lock() - _, err := c.KeyLogWriter.Write(logLine) - writerMutex.Unlock() - - return err -} - -// writerMutex protects all KeyLogWriters globally. It is rarely enabled, -// and is only for debugging, so a global mutex saves space. -var writerMutex sync.Mutex - -// A Certificate is a chain of one or more certificates, leaf first. -type Certificate struct { - Certificate [][]byte - // PrivateKey contains the private key corresponding to the public key in - // Leaf. This must implement crypto.Signer with an RSA, ECDSA or Ed25519 PublicKey. - // For a server up to TLS 1.2, it can also implement crypto.Decrypter with - // an RSA PublicKey. - PrivateKey crypto.PrivateKey - // SupportedSignatureAlgorithms is an optional list restricting what - // signature algorithms the PrivateKey can be used for. - SupportedSignatureAlgorithms []SignatureScheme - // OCSPStaple contains an optional OCSP response which will be served - // to clients that request it. - OCSPStaple []byte - // SignedCertificateTimestamps contains an optional list of Signed - // Certificate Timestamps which will be served to clients that request it. - SignedCertificateTimestamps [][]byte - // Leaf is the parsed form of the leaf certificate, which may be initialized - // using x509.ParseCertificate to reduce per-handshake processing. If nil, - // the leaf certificate will be parsed as needed. - Leaf *x509.Certificate -} - -// leaf returns the parsed leaf certificate, either from c.Leaf or by parsing -// the corresponding c.Certificate[0]. -func (c *Certificate) leaf() (*x509.Certificate, error) { - if c.Leaf != nil { - return c.Leaf, nil - } - return x509.ParseCertificate(c.Certificate[0]) -} - -type handshakeMessage interface { - marshal() ([]byte, error) - unmarshal([]byte) bool -} - -// lruSessionCache is a ClientSessionCache implementation that uses an LRU -// caching strategy. -type lruSessionCache struct { - sync.Mutex - - m map[string]*list.Element - q *list.List - capacity int -} - -type lruSessionCacheEntry struct { - sessionKey string - state *ClientSessionState -} - -// NewLRUClientSessionCache returns a ClientSessionCache with the given -// capacity that uses an LRU strategy. If capacity is < 1, a default capacity -// is used instead. -func NewLRUClientSessionCache(capacity int) ClientSessionCache { - const defaultSessionCacheCapacity = 64 - - if capacity < 1 { - capacity = defaultSessionCacheCapacity - } - return &lruSessionCache{ - m: make(map[string]*list.Element), - q: list.New(), - capacity: capacity, - } -} - -// Put adds the provided (sessionKey, cs) pair to the cache. If cs is nil, the entry -// corresponding to sessionKey is removed from the cache instead. -func (c *lruSessionCache) Put(sessionKey string, cs *ClientSessionState) { - c.Lock() - defer c.Unlock() - - if elem, ok := c.m[sessionKey]; ok { - if cs == nil { - c.q.Remove(elem) - delete(c.m, sessionKey) - } else { - entry := elem.Value.(*lruSessionCacheEntry) - entry.state = cs - c.q.MoveToFront(elem) - } - return - } - - if c.q.Len() < c.capacity { - entry := &lruSessionCacheEntry{sessionKey, cs} - c.m[sessionKey] = c.q.PushFront(entry) - return - } - - elem := c.q.Back() - entry := elem.Value.(*lruSessionCacheEntry) - delete(c.m, entry.sessionKey) - entry.sessionKey = sessionKey - entry.state = cs - c.q.MoveToFront(elem) - c.m[sessionKey] = elem -} - -// Get returns the ClientSessionState value associated with a given key. It -// returns (nil, false) if no value is found. -func (c *lruSessionCache) Get(sessionKey string) (*ClientSessionState, bool) { - c.Lock() - defer c.Unlock() - - if elem, ok := c.m[sessionKey]; ok { - c.q.MoveToFront(elem) - return elem.Value.(*lruSessionCacheEntry).state, true - } - return nil, false -} - -var emptyConfig Config - -func defaultConfig() *Config { - return &emptyConfig -} - -func unexpectedMessageError(wanted, got any) error { - return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted) -} - -func isSupportedSignatureAlgorithm(sigAlg SignatureScheme, supportedSignatureAlgorithms []SignatureScheme) bool { - for _, s := range supportedSignatureAlgorithms { - if s == sigAlg { - return true - } - } - return false -} - -// CertificateVerificationError is returned when certificate verification fails during the handshake. -type CertificateVerificationError struct { - // UnverifiedCertificates and its contents should not be modified. - UnverifiedCertificates []*x509.Certificate - Err error -} - -func (e *CertificateVerificationError) Error() string { - return fmt.Sprintf("tls: failed to verify certificate: %s", e.Err) -} - -func (e *CertificateVerificationError) Unwrap() error { - return e.Err -} diff --git a/pkg/tls/common_string.go b/pkg/tls/common_string.go deleted file mode 100644 index 238108811..000000000 --- a/pkg/tls/common_string.go +++ /dev/null @@ -1,116 +0,0 @@ -// Code generated by "stringer -type=SignatureScheme,CurveID,ClientAuthType -output=common_string.go"; DO NOT EDIT. - -package tls - -import "strconv" - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[PKCS1WithSHA256-1025] - _ = x[PKCS1WithSHA384-1281] - _ = x[PKCS1WithSHA512-1537] - _ = x[PSSWithSHA256-2052] - _ = x[PSSWithSHA384-2053] - _ = x[PSSWithSHA512-2054] - _ = x[ECDSAWithP256AndSHA256-1027] - _ = x[ECDSAWithP384AndSHA384-1283] - _ = x[ECDSAWithP521AndSHA512-1539] - _ = x[Ed25519-2055] - _ = x[PKCS1WithSHA1-513] - _ = x[ECDSAWithSHA1-515] -} - -const ( - _SignatureScheme_name_0 = "PKCS1WithSHA1" - _SignatureScheme_name_1 = "ECDSAWithSHA1" - _SignatureScheme_name_2 = "PKCS1WithSHA256" - _SignatureScheme_name_3 = "ECDSAWithP256AndSHA256" - _SignatureScheme_name_4 = "PKCS1WithSHA384" - _SignatureScheme_name_5 = "ECDSAWithP384AndSHA384" - _SignatureScheme_name_6 = "PKCS1WithSHA512" - _SignatureScheme_name_7 = "ECDSAWithP521AndSHA512" - _SignatureScheme_name_8 = "PSSWithSHA256PSSWithSHA384PSSWithSHA512Ed25519" -) - -var ( - _SignatureScheme_index_8 = [...]uint8{0, 13, 26, 39, 46} -) - -func (i SignatureScheme) String() string { - switch { - case i == 513: - return _SignatureScheme_name_0 - case i == 515: - return _SignatureScheme_name_1 - case i == 1025: - return _SignatureScheme_name_2 - case i == 1027: - return _SignatureScheme_name_3 - case i == 1281: - return _SignatureScheme_name_4 - case i == 1283: - return _SignatureScheme_name_5 - case i == 1537: - return _SignatureScheme_name_6 - case i == 1539: - return _SignatureScheme_name_7 - case 2052 <= i && i <= 2055: - i -= 2052 - return _SignatureScheme_name_8[_SignatureScheme_index_8[i]:_SignatureScheme_index_8[i+1]] - default: - return "SignatureScheme(" + strconv.FormatInt(int64(i), 10) + ")" - } -} -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[CurveP256-23] - _ = x[CurveP384-24] - _ = x[CurveP521-25] - _ = x[X25519-29] -} - -const ( - _CurveID_name_0 = "CurveP256CurveP384CurveP521" - _CurveID_name_1 = "X25519" -) - -var ( - _CurveID_index_0 = [...]uint8{0, 9, 18, 27} -) - -func (i CurveID) String() string { - switch { - case 23 <= i && i <= 25: - i -= 23 - return _CurveID_name_0[_CurveID_index_0[i]:_CurveID_index_0[i+1]] - case i == 29: - return _CurveID_name_1 - default: - return "CurveID(" + strconv.FormatInt(int64(i), 10) + ")" - } -} -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[NoClientCert-0] - _ = x[RequestClientCert-1] - _ = x[RequireAnyClientCert-2] - _ = x[VerifyClientCertIfGiven-3] - _ = x[RequireAndVerifyClientCert-4] -} - -const _ClientAuthType_name = "NoClientCertRequestClientCertRequireAnyClientCertVerifyClientCertIfGivenRequireAndVerifyClientCert" - -var _ClientAuthType_index = [...]uint8{0, 12, 29, 49, 72, 98} - -func (i ClientAuthType) String() string { - if i < 0 || i >= ClientAuthType(len(_ClientAuthType_index)-1) { - return "ClientAuthType(" + strconv.FormatInt(int64(i), 10) + ")" - } - return _ClientAuthType_name[_ClientAuthType_index[i]:_ClientAuthType_index[i+1]] -} diff --git a/pkg/tls/conn.go b/pkg/tls/conn.go deleted file mode 100644 index 9e1398c11..000000000 --- a/pkg/tls/conn.go +++ /dev/null @@ -1,1546 +0,0 @@ -// Copyright 2010 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// TLS low level connection and record layer - -package tls - -import ( - "bytes" - "context" - "crypto/cipher" - "crypto/subtle" - "crypto/x509" - "errors" - "fmt" - "hash" - "io" - "net" - "sync" - "time" -) - -// Socket is a set of functions which manipulate the underlying file descriptor of a connection. -type Socket interface { - // Fd returns the underlying file descriptor. - Fd() int - WriteTCP([]byte) (int, error) -} - -// A Conn represents a secured connection. -// It implements the net.Conn interface. -type Conn struct { - // constant - conn net.Conn - isClient bool - handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake - - // handshakeStatus is 1 if the connection is currently transferring - // application data (i.e. is not currently processing a handshake). - // This field is only to be accessed with sync/atomic. - handshakeStatus uint8 - // constant after handshake; protected by handshakeMutex - handshakeMutex sync.Mutex - handshakeErr error // error resulting from handshake - vers uint16 // TLS version - haveVers bool // version has been negotiated - config *Config // configuration passed to constructor - // handshakes counts the number of handshakes performed on the - // connection so far. If renegotiation is disabled then this is either - // zero or one. - handshakes int - didResume bool // whether this connection was a session resumption - cipherSuite uint16 - ocspResponse []byte // stapled OCSP response - scts [][]byte // signed certificate timestamps from server - peerCertificates []*x509.Certificate - // activeCertHandles contains the cache handles to certificates in - // peerCertificates that are used to track active references. - activeCertHandles []*activeCert - // verifiedChains contains the certificate chains that we built, as - // opposed to the ones presented by the server. - verifiedChains [][]*x509.Certificate - // serverName contains the server name indicated by the client, if any. - serverName string - // secureRenegotiation is true if the server echoed the secure - // renegotiation extension. (This is meaningless as a server because - // renegotiation is not supported in that case.) - secureRenegotiation bool - // ekm is a closure for exporting keying material. - ekm func(label string, context []byte, length int) ([]byte, error) - // resumptionSecret is the resumption_master_secret for handling - // NewSessionTicket messages. nil if config.SessionTicketsDisabled. - resumptionSecret []byte - - // ticketKeys is the set of active session ticket keys for this - // connection. The first one is used to encrypt new tickets and - // all are tried to decrypt tickets. - ticketKeys []ticketKey - - // clientFinishedIsFirst is true if the client sent the first Finished - // message during the most recent handshake. This is recorded because - // the first transmitted Finished message is the tls-unique - // channel-binding value. - clientFinishedIsFirst bool - - // closeNotifyErr is any error from sending the alertCloseNotify record. - closeNotifyErr error - // closeNotifySent is true if the Conn attempted to send an - // alertCloseNotify record. - closeNotifySent bool - - // clientFinished and serverFinished contain the Finished message sent - // by the client or server in the most recent handshake. This is - // retained to support the renegotiation extension and tls-unique - // channel-binding. - clientFinished [12]byte - serverFinished [12]byte - - // clientProtocol is the negotiated ALPN protocol. - clientProtocol string - - // input/output - // By using the elastic MsgBuffer the tls conn not longer holds the actual buffer when the connection is idle. - // This can significantly optimize the memory usage, especially when the server connecting millions of clients - // where most of them are idle. - in, out halfConn - rawInput LazyBuffer // raw input, starting with a record header - input bytes.Reader // a buffer for decrypted records pointer to the inboundBuffer of gnet.conn - data []byte // buffer to hold all decrypted data - hand LazyBuffer // handshake data waiting to be read - buffering bool // whether records are buffered in sendBuf - sendBuf LazyBuffer // a buffer for records waiting to be sent also point to the outboundBuffer of gnet.conn - - // bytesSent counts the bytes of application data sent. - // packetsSent counts packets. - bytesSent int64 - packetsSent int64 - - // retryCount counts the number of consecutive non-advancing records - // received by Conn.readRecord. That is, records that neither advance the - // handshake, nor deliver application data. Protected by in.Mutex. - retryCount int - - // activeCall indicates whether Close has been call in the low bit. - // the rest of the bits are the number of goroutines in Conn.Write. - // activeCall atomic.Int32 - - tmp [16]byte - hs interface { - handshake() error - } -} - -// Access to net.Conn methods. -// Cannot just embed net.Conn because that would -// export the struct field too. - -// LocalAddr returns the local network address. -func (c *Conn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -// RemoteAddr returns the remote network address. -func (c *Conn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -// SetDeadline sets the read and write deadlines associated with the connection. -// A zero value for t means Read and Write will not time out. -// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. -func (c *Conn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} - -// SetReadDeadline sets the read deadline on the underlying connection. -// A zero value for t means Read will not time out. -func (c *Conn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -// SetWriteDeadline sets the write deadline on the underlying connection. -// A zero value for t means Write will not time out. -// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. -func (c *Conn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) -} - -// NetConn returns the underlying connection that is wrapped by c. -// Note that writing to or reading from this connection directly will corrupt the -// TLS session. -func (c *Conn) NetConn() net.Conn { - return c.conn -} - -// A halfConn represents one direction of the record layer -// connection, either sending or receiving. -type halfConn struct { - sync.Mutex - - err error // first permanent error - version uint16 // protocol version - cipher interface{} // cipher algorithm - mac hash.Hash - seq [8]byte // 64-bit sequence number - - scratchBuf [13]byte // to avoid allocs; interface method args escape - - nextCipher interface{} // next encryption state - nextMac hash.Hash // next MAC algorithm - - trafficSecret []byte // current TLS 1.3 traffic secret - - key []byte // encrypt or decrypt key for kernel tls - iv []byte // encrypt or decrypt iv for kernel tls -} - -type permanentError struct { - err net.Error -} - -func (e *permanentError) Error() string { return e.err.Error() } -func (e *permanentError) Unwrap() error { return e.err } -func (e *permanentError) Timeout() bool { return e.err.Timeout() } -func (e *permanentError) Temporary() bool { return false } - -func (hc *halfConn) setErrorLocked(err error) error { - if e, ok := err.(net.Error); ok { - hc.err = &permanentError{err: e} - } else { - hc.err = err - } - return hc.err -} - -// prepareCipherSpec sets the encryption and MAC states -// that a subsequent changeCipherSpec will use. -func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac hash.Hash) { - hc.version = version - hc.nextCipher = cipher - hc.nextMac = mac -} - -// changeCipherSpec changes the encryption and MAC states -// to the ones previously passed to prepareCipherSpec. -func (hc *halfConn) changeCipherSpec() error { - if hc.nextCipher == nil || hc.version == VersionTLS13 { - return alertInternalError - } - hc.cipher = hc.nextCipher - hc.mac = hc.nextMac - hc.nextCipher = nil - hc.nextMac = nil - for i := range hc.seq { - hc.seq[i] = 0 - } - return nil -} - -func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) { - hc.trafficSecret = secret - hc.key, hc.iv = suite.trafficKey(secret) - hc.cipher = suite.aead(hc.key, hc.iv) - for i := range hc.seq { - hc.seq[i] = 0 - } -} - -// incSeq increments the sequence number. -func (hc *halfConn) incSeq() { - for i := 7; i >= 0; i-- { - hc.seq[i]++ - if hc.seq[i] != 0 { - return - } - } - - // Not allowed to let sequence number wrap. - // Instead, must renegotiate before it does. - // Not likely enough to bother. - panic("TLS: sequence number wraparound") -} - -// explicitNonceLen returns the number of bytes of explicit nonce or IV included -// in each record. Explicit nonces are present only in CBC modes after TLS 1.0 -// and in certain AEAD modes in TLS 1.2. -func (hc *halfConn) explicitNonceLen() int { - if hc.cipher == nil { - return 0 - } - - switch c := hc.cipher.(type) { - case cipher.Stream: - return 0 - case aead: - return c.explicitNonceLen() - case cbcMode: - // TLS 1.1 introduced a per-record explicit IV to fix the BEAST attack. - if hc.version >= VersionTLS11 { - return c.BlockSize() - } - return 0 - // never reached, thus dead code - // case kTLSCipher: - // return 0 - default: - panic("unknown cipher type") - } -} - -// extractPadding returns, in constant time, the length of the padding to remove -// from the end of payload. It also returns a byte which is equal to 255 if the -// padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2. -func extractPadding(payload []byte) (toRemove int, good byte) { - if len(payload) < 1 { - return 0, 0 - } - - paddingLen := payload[len(payload)-1] - t := uint(len(payload)-1) - uint(paddingLen) - // if len(payload) >= (paddingLen - 1) then the MSB of t is zero - good = byte(int32(^t) >> 31) - - // The maximum possible padding length plus the actual length field - toCheck := 256 - // The length of the padded data is public, so we can use an if here - if toCheck > len(payload) { - toCheck = len(payload) - } - - for i := 0; i < toCheck; i++ { - t := uint(paddingLen) - uint(i) - // if i <= paddingLen then the MSB of t is zero - mask := byte(int32(^t) >> 31) - b := payload[len(payload)-1-i] - good &^= mask&paddingLen ^ mask&b - } - - // We AND together the bits of good and replicate the result across - // all the bits. - good &= good << 4 - good &= good << 2 - good &= good << 1 - good = uint8(int8(good) >> 7) - - // Zero the padding length on error. This ensures any unchecked bytes - // are included in the MAC. Otherwise, an attacker that could - // distinguish MAC failures from padding failures could mount an attack - // similar to POODLE in SSL 3.0: given a good ciphertext that uses a - // full block's worth of padding, replace the final block with another - // block. If the MAC check passed but the padding check failed, the - // last byte of that block decrypted to the block size. - // - // See also macAndPaddingGood logic below. - paddingLen &= good - - toRemove = int(paddingLen) + 1 - return -} - -func roundUp(a, b int) int { - return a + (b-a%b)%b -} - -// cbcMode is an interface for block ciphers using cipher block chaining. -type cbcMode interface { - cipher.BlockMode - SetIV([]byte) -} - -// decrypt authenticates and decrypts the record if protection is active at -// this stage. The returned plaintext might overlap with the input. -func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) { - var plaintext []byte - typ := recordType(record[0]) - payload := record[recordHeaderLen:] - - // In TLS 1.3, change_cipher_spec messages are to be ignored without being - // decrypted. See RFC 8446, Appendix D.4. - if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec { - return payload, typ, nil - } - - paddingGood := byte(255) - paddingLen := 0 - - explicitNonceLen := hc.explicitNonceLen() - - if hc.cipher != nil { - switch c := hc.cipher.(type) { - case cipher.Stream: - c.XORKeyStream(payload, payload) - case aead: - if len(payload) < explicitNonceLen { - return nil, 0, alertBadRecordMAC - } - nonce := payload[:explicitNonceLen] - if len(nonce) == 0 { - nonce = hc.seq[:] - } - payload = payload[explicitNonceLen:] - - var additionalData []byte - if hc.version == VersionTLS13 { - additionalData = record[:recordHeaderLen] - } else { - additionalData = append(hc.scratchBuf[:0], hc.seq[:]...) - additionalData = append(additionalData, record[:3]...) - n := len(payload) - c.Overhead() - additionalData = append(additionalData, byte(n>>8), byte(n)) - } - - var err error - plaintext, err = c.Open(payload[:0], nonce, payload, additionalData) - if err != nil { - return nil, 0, alertBadRecordMAC - } - case cbcMode: - blockSize := c.BlockSize() - minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize) - if len(payload)%blockSize != 0 || len(payload) < minPayload { - return nil, 0, alertBadRecordMAC - } - - if explicitNonceLen > 0 { - c.SetIV(payload[:explicitNonceLen]) - payload = payload[explicitNonceLen:] - } - c.CryptBlocks(payload, payload) - - // In a limited attempt to protect against CBC padding oracles like - // Lucky13, the data past paddingLen (which is secret) is passed to - // the MAC function as extra data, to be fed into the HMAC after - // computing the digest. This makes the MAC roughly constant time as - // long as the digest computation is constant time and does not - // affect the subsequent write, modulo cache effects. - paddingLen, paddingGood = extractPadding(payload) - default: - panic("unknown cipher type") - } - - if hc.version == VersionTLS13 { - if typ != recordTypeApplicationData { - return nil, 0, alertUnexpectedMessage - } - if len(plaintext) > maxPlaintext+1 { - return nil, 0, alertRecordOverflow - } - // Remove padding and find the ContentType scanning from the end. - for i := len(plaintext) - 1; i >= 0; i-- { - if plaintext[i] != 0 { - typ = recordType(plaintext[i]) - plaintext = plaintext[:i] - break - } - if i == 0 { - return nil, 0, alertUnexpectedMessage - } - } - } - } else { - plaintext = payload - } - - if hc.mac != nil { - macSize := hc.mac.Size() - if len(payload) < macSize { - return nil, 0, alertBadRecordMAC - } - - n := len(payload) - macSize - paddingLen - n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 } - record[3] = byte(n >> 8) - record[4] = byte(n) - remoteMAC := payload[n : n+macSize] - localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:]) - - // This is equivalent to checking the MACs and paddingGood - // separately, but in constant-time to prevent distinguishing - // padding failures from MAC failures. Depending on what value - // of paddingLen was returned on bad padding, distinguishing - // bad MAC from bad padding can lead to an attack. - // - // See also the logic at the end of extractPadding. - macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood) - if macAndPaddingGood != 1 { - return nil, 0, alertBadRecordMAC - } - - plaintext = payload[:n] - } - - hc.incSeq() - return plaintext, typ, nil -} - -// sliceForAppend extends the input slice by n bytes. head is the full extended -// slice, while tail is the appended part. If the original slice has sufficient -// capacity no allocation is performed. -func sliceForAppend(in []byte, n int) (head, tail []byte) { - if total := len(in) + n; cap(in) >= total { - head = in[:total] - } else { - head = make([]byte, total) - copy(head, in) - } - tail = head[len(in):] - return -} - -// encrypt encrypts payload, adding the appropriate nonce and/or MAC, and -// appends it to record, which must already contain the record header. -func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) { - if hc.cipher == nil { - return append(record, payload...), nil - } - - var explicitNonce []byte - if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 { - record, explicitNonce = sliceForAppend(record, explicitNonceLen) - if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 { - // The AES-GCM construction in TLS has an explicit nonce so that the - // nonce can be random. However, the nonce is only 8 bytes which is - // too small for a secure, random nonce. Therefore we use the - // sequence number as the nonce. The 3DES-CBC construction also has - // an 8 bytes nonce but its nonces must be unpredictable (see RFC - // 5246, Appendix F.3), forcing us to use randomness. That's not - // 3DES' biggest problem anyway because the birthday bound on block - // collision is reached first due to its similarly small block size - // (see the Sweet32 attack). - copy(explicitNonce, hc.seq[:]) - } else { - if _, err := io.ReadFull(rand, explicitNonce); err != nil { - return nil, err - } - } - } - - var dst []byte - switch c := hc.cipher.(type) { - case cipher.Stream: - mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) - record, dst = sliceForAppend(record, len(payload)+len(mac)) - c.XORKeyStream(dst[:len(payload)], payload) - c.XORKeyStream(dst[len(payload):], mac) - case aead: - nonce := explicitNonce - if len(nonce) == 0 { - nonce = hc.seq[:] - } - - if hc.version == VersionTLS13 { - record = append(record, payload...) - - // Encrypt the actual ContentType and replace the plaintext one. - record = append(record, record[0]) - record[0] = byte(recordTypeApplicationData) - - n := len(payload) + 1 + c.Overhead() - record[3] = byte(n >> 8) - record[4] = byte(n) - - record = c.Seal(record[:recordHeaderLen], - nonce, record[recordHeaderLen:], record[:recordHeaderLen]) - } else { - additionalData := append(hc.scratchBuf[:0], hc.seq[:]...) - additionalData = append(additionalData, record[:recordHeaderLen]...) - record = c.Seal(record, nonce, payload, additionalData) - } - case cbcMode: - mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) - blockSize := c.BlockSize() - plaintextLen := len(payload) + len(mac) - paddingLen := blockSize - plaintextLen%blockSize - record, dst = sliceForAppend(record, plaintextLen+paddingLen) - copy(dst, payload) - copy(dst[len(payload):], mac) - for i := plaintextLen; i < len(dst); i++ { - dst[i] = byte(paddingLen - 1) - } - if len(explicitNonce) > 0 { - c.SetIV(explicitNonce) - } - c.CryptBlocks(dst, dst) - default: - panic("unknown cipher type") - } - - // Update length to include nonce, MAC and any block padding needed. - n := len(record) - recordHeaderLen - record[3] = byte(n >> 8) - record[4] = byte(n) - hc.incSeq() - - return record, nil -} - -// RecordHeaderError is returned when a TLS record header is invalid. -type RecordHeaderError struct { - // Msg contains a human readable string that describes the error. - Msg string - // RecordHeader contains the five bytes of TLS record header that - // triggered the error. - RecordHeader [5]byte - // Conn provides the underlying net.Conn in the case that a client - // sent an initial handshake that didn't look like TLS. - // It is nil if there's already been a handshake or a TLS alert has - // been written to the connection. - Conn net.Conn -} - -func (e RecordHeaderError) Error() string { return "tls: " + e.Msg } - -func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) { - err.Msg = msg - err.Conn = conn - copy(err.RecordHeader[:], c.rawInput.Bytes()) - return err -} - -func (c *Conn) readRecord() error { - return c.readRecordOrCCS(false) -} - -func (c *Conn) readChangeCipherSpec() error { - return c.readRecordOrCCS(true) - -} - -// ktlsInBufPool pools the buffers used by ktlsReadRecord. -var ktlsInBufPool = sync.Pool{ - New: func() any { - buf := make([]byte, maxPlaintext) - return &buf - }, -} - -// readRecordOrCCS reads one or more TLS records from the connection and -// updates the record layer state. Some invariants: -// - c.in must be locked -// - c.input must be empty -// -// During the handshake one and only one of the following will happen: -// - c.hand grows -// - c.in.changeCipherSpec is called -// - an error is returned -// -// After the handshake one and only one of the following will happen: -// - c.hand grows -// - c.input is set -// - an error is returned -func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { - if c.in.err != nil { - return c.in.err - } - handshakeComplete := c.HandshakeComplete() - - var ( - typ recordType - data []byte - record []byte - hdr []byte - n int - vers uint16 - err error - ) - - if _, ok := c.in.cipher.(kTLSCipher); ok { - if c.rawInput.Len() < maxPlaintext { - c.rawInput.Extend(maxPlaintext - c.rawInput.Len()) - } - data = c.rawInput.Bytes()[:maxPlaintext] - if typ, n, err = ktlsReadRecord(c.conn.(Socket).Fd(), data); err != nil { - return err - } - data = c.rawInput.Next(n) - // TODO: process the data here instead of goto processMessage - // && try to use ktlsReadRecord to write data directly into input - // rather than copy it later. - goto processMessage - } - - if c.rawInput.Len() <= recordHeaderLen { - return io.EOF - } - - hdr = c.rawInput.Bytes() - typ = recordType(hdr[0]) - - // No valid TLS record has a type of 0x80, however SSLv2 handshakes - // start with a uint16 length where the MSB is set and the first record - // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests - // an SSLv2 client. - if !handshakeComplete && typ == 0x80 { - c.sendAlert(alertProtocolVersion) - return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received")) - } - - vers = uint16(hdr[1])<<8 | uint16(hdr[2]) - n = int(hdr[3])<<8 | int(hdr[4]) - - if len(hdr) < recordHeaderLen+n { - return io.EOF - } - - if c.haveVers && c.vers != VersionTLS13 && vers != c.vers { - c.sendAlert(alertProtocolVersion) - msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers) - return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) - } - if !c.haveVers { - // First message, be extra suspicious: this might not be a TLS - // client. Bail out before reading a full 'body', if possible. - // The current max version is 3.3 so if the version is >= 16.0, - // it's probably not real. - if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 { - return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake")) - } - } - if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext { - c.sendAlert(alertRecordOverflow) - msg := fmt.Sprintf("oversized record received with length %d", n) - return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) - } - - // Process message. - record = c.rawInput.Next(recordHeaderLen + n) - data, typ, err = c.in.decrypt(record) - if err != nil { - return c.in.setErrorLocked(c.sendAlert(err.(alert))) - } - -processMessage: - if len(data) > maxPlaintext { - return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow)) - } - - // Application Data messages are always protected. - if c.in.cipher == nil && typ == recordTypeApplicationData { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - - if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 { - // This is a state-advancing message: reset the retry count. - c.retryCount = 0 - } - - // Handshake messages MUST NOT be interleaved with other record types in TLS 1.3. - if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - - switch typ { - default: - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - - case recordTypeAlert: - if len(data) != 2 { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - if alert(data[1]) == alertCloseNotify { - return c.in.setErrorLocked(io.EOF) - } - if c.vers == VersionTLS13 { - return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) - } - switch data[0] { - case alertLevelWarning: - // Drop the record on the floor and retry. - return c.retryReadRecord(expectChangeCipherSpec) - case alertLevelError: - return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) - default: - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - - case recordTypeChangeCipherSpec: - if len(data) != 1 || data[0] != 1 { - return c.in.setErrorLocked(c.sendAlert(alertDecodeError)) - } - // Handshake messages are not allowed to fragment across the CCS. - if c.hand.Len() > 0 { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - // In TLS 1.3, change_cipher_spec records are ignored until the - // Finished. See RFC 8446, Appendix D.4. Note that according to Section - // 5, a server can send a ChangeCipherSpec before its ServerHello, when - // c.vers is still unset. That's not useful though and suspicious if the - // server then selects a lower protocol version, so don't allow that. - if c.vers == VersionTLS13 { - return c.retryReadRecord(expectChangeCipherSpec) - } - if !expectChangeCipherSpec { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - if err := c.in.changeCipherSpec(); err != nil { - return c.in.setErrorLocked(c.sendAlert(err.(alert))) - } - - case recordTypeApplicationData: - if !handshakeComplete || expectChangeCipherSpec { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - // Some OpenSSL servers send empty records in order to randomize the - // CBC IV. Ignore a limited number of empty records. - if len(data) == 0 { - return c.retryReadRecord(expectChangeCipherSpec) - } - // Note that data is owned by c.rawInput, following the Next call above, - // to avoid copying the plaintext. This is safe because c.rawInput is - // not read from or written to until c.input is drained. - c.input.Reset(data) - c.data = data - - case recordTypeHandshake: - if len(data) == 0 || expectChangeCipherSpec { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - c.hand.Set(data) - } - - return nil -} - -// retryReadRecord recurs into readRecordOrCCS to drop a non-advancing record, like -// a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3. -func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error { - c.retryCount++ - if c.retryCount > maxUselessRecords { - c.sendAlert(alertUnexpectedMessage) - return c.in.setErrorLocked(errors.New("tls: too many ignored records")) - } - return c.readRecordOrCCS(expectChangeCipherSpec) -} - -// sendAlert sends a TLS alert message. -func (c *Conn) sendAlertLocked(err alert) error { - switch err { - case alertNoRenegotiation, alertCloseNotify: - c.tmp[0] = alertLevelWarning - default: - c.tmp[0] = alertLevelError - } - c.tmp[1] = byte(err) - - _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2]) - if err == alertCloseNotify { - // closeNotify is a special case in that it isn't an error. - return writeErr - } - - return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) -} - -// sendAlert sends a TLS alert message. -func (c *Conn) sendAlert(err alert) error { - return c.sendAlertLocked(err) -} - -const ( - // tcpMSSEstimate is a conservative estimate of the TCP maximum segment - // size (MSS). A constant is used, rather than querying the kernel for - // the actual MSS, to avoid complexity. The value here is the IPv6 - // minimum MTU (1280 bytes) minus the overhead of an IPv6 header (40 - // bytes) and a TCP header with timestamps (32 bytes). - tcpMSSEstimate = 1208 - - // recordSizeBoostThreshold is the number of bytes of application data - // sent after which the TLS record size will be increased to the - // maximum. - recordSizeBoostThreshold = 128 * 1024 -) - -// maxPayloadSizeForWrite returns the maximum TLS payload size to use for the -// next application data record. There is the following trade-off: -// -// - For latency-sensitive applications, such as web browsing, each TLS -// record should fit in one TCP segment. -// - For throughput-sensitive applications, such as large file transfers, -// larger TLS records better amortize framing and encryption overheads. -// -// A simple heuristic that works well in practice is to use small records for -// the first 1MB of data, then use larger records for subsequent data, and -// reset back to smaller records after the connection becomes idle. See "High -// Performance Web Networking", Chapter 4, or: -// https://www.igvita.com/2013/10/24/optimizing-tls-record-size-and-buffering-latency/ -// -// In the interests of simplicity and determinism, this code does not attempt -// to reset the record size once the connection is idle, however. -func (c *Conn) maxPayloadSizeForWrite(typ recordType) int { - if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData { - return maxPlaintext - } - - if c.bytesSent >= recordSizeBoostThreshold { - return maxPlaintext - } - - // Subtract TLS overheads to get the maximum payload size. - payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen() - if c.out.cipher != nil { - switch ciph := c.out.cipher.(type) { - case cipher.Stream: - payloadBytes -= c.out.mac.Size() - case cipher.AEAD: - payloadBytes -= ciph.Overhead() - case cbcMode: - blockSize := ciph.BlockSize() - // The payload must fit in a multiple of blockSize, with - // room for at least one padding byte. - payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1 - // The MAC is appended before padding so affects the - // payload size directly. - payloadBytes -= c.out.mac.Size() - // never reached, thus dead code - // case kTLSCipher: - // payloadBytes -= kTLSOverhead - default: - panic("unknown cipher type") - } - } - if c.vers == VersionTLS13 { - payloadBytes-- // encrypted ContentType - } - - // Allow packet growth in arithmetic progression up to max. - pkt := c.packetsSent - c.packetsSent++ - if pkt > 1000 { - return maxPlaintext // avoid overflow in multiply below - } - - n := payloadBytes * int(pkt+1) - if n > maxPlaintext { - n = maxPlaintext - } - return n -} - -func (c *Conn) write(data []byte) (int, error) { - if c.buffering { - _, _ = c.sendBuf.Write(data) - return len(data), nil - } - - n, err := c.conn.(Socket).WriteTCP(data) - c.bytesSent += int64(n) - return n, err -} - -func (c *Conn) flush() (int, error) { - if c.sendBuf.Len() == 0 { - return 0, nil - } - n, err := c.conn.(Socket).WriteTCP(c.sendBuf.Bytes()) - c.bytesSent += int64(n) - c.sendBuf.Done() - c.buffering = false - return n, err -} - -// outBufPool pools the record-sized scratch buffers used by writeRecordLocked. -var outBufPool = sync.Pool{ - New: func() any { - return new([]byte) - }, -} - -// writeRecordLocked writes a TLS record with the given type and payload to the -// connection and updates the record layer state. -func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { - if _, ok := c.out.cipher.(kTLSCipher); ok { - switch typ { - case recordTypeAlert: - return ktlsSendCtrlMessage(c.conn.(Socket).Fd(), typ, data) - case recordTypeHandshake, recordTypeChangeCipherSpec: - return ktlsSendCtrlMessage(c.conn.(Socket).Fd(), typ, data) - case recordTypeApplicationData: - return c.write(data) - default: - panic("unknown record type") - } - } - outBufPtr := outBufPool.Get().(*[]byte) - outBuf := *outBufPtr - defer func() { - // You might be tempted to simplify this by just passing &outBuf to Put, - // but that would make the local copy of the outBuf slice header escape - // to the heap, causing an allocation. Instead, we keep around the - // pointer to the slice header returned by Get, which is already on the - // heap, and overwrite and return that. - *outBufPtr = outBuf - outBufPool.Put(outBufPtr) - }() - - var n int - for len(data) > 0 { - m := len(data) - if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload { - m = maxPayload - } - - _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen) - outBuf[0] = byte(typ) - vers := c.vers - if vers == 0 { - // Some TLS servers fail if the record version is - // greater than TLS 1.0 for the initial ClientHello. - vers = VersionTLS10 - } else if vers == VersionTLS13 { - // TLS 1.3 froze the record layer version to 1.2. - // See RFC 8446, Section 5.1. - vers = VersionTLS12 - } - outBuf[1] = byte(vers >> 8) - outBuf[2] = byte(vers) - outBuf[3] = byte(m >> 8) - outBuf[4] = byte(m) - - var err error - outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand()) - if err != nil { - return n, err - } - if _, err := c.write(outBuf); err != nil { - return n, err - } - n += m - data = data[m:] - } - - if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 { - if err := c.out.changeCipherSpec(); err != nil { - return n, c.sendAlertLocked(err.(alert)) - } - } - - return n, nil -} - -// writeHandshakeRecord writes a handshake message to the connection and updates -// the record layer state. If transcript is non-nil the marshalled message is -// written to it. -func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) { - data, err := msg.marshal() - if err != nil { - return 0, err - } - if transcript != nil { - transcript.Write(data) - } - - return c.writeRecordLocked(recordTypeHandshake, data) -} - -// writeChangeCipherRecord writes a ChangeCipherSpec message to the connection and -// updates the record layer state. -func (c *Conn) writeChangeCipherRecord() error { - _, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1}) - return err -} - -// readHandshake reads the next handshake message from -// the record layer. If transcript is non-nil, the message -// is written to the passed transcriptHash. -func (c *Conn) readHandshake(transcript transcriptHash) (any, error) { - for c.hand.Len() < 4 { - if err := c.readRecord(); err != nil { - return nil, err - } - } - - data := c.hand.Bytes() - n := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) - if n > maxHandshake { - c.sendAlertLocked(alertInternalError) - return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)) - } - for c.hand.Len() < 4+n { - if err := c.readRecord(); err != nil { - return nil, err - } - } - data = c.hand.Next(4 + n) - // Handshake messages are all processed, return the buffer to the pool - defer c.hand.Done() - var m handshakeMessage - switch data[0] { - case typeHelloRequest: - m = new(helloRequestMsg) - case typeClientHello: - m = new(clientHelloMsg) - case typeServerHello: - m = new(serverHelloMsg) - case typeNewSessionTicket: - if c.vers == VersionTLS13 { - m = new(newSessionTicketMsgTLS13) - } else { - m = new(newSessionTicketMsg) - } - case typeCertificate: - if c.vers == VersionTLS13 { - m = new(certificateMsgTLS13) - } else { - m = new(certificateMsg) - } - case typeCertificateRequest: - if c.vers == VersionTLS13 { - m = new(certificateRequestMsgTLS13) - } else { - m = &certificateRequestMsg{ - hasSignatureAlgorithm: c.vers >= VersionTLS12, - } - } - case typeCertificateStatus: - m = new(certificateStatusMsg) - case typeServerKeyExchange: - m = new(serverKeyExchangeMsg) - case typeServerHelloDone: - m = new(serverHelloDoneMsg) - case typeClientKeyExchange: - m = new(clientKeyExchangeMsg) - case typeCertificateVerify: - m = &certificateVerifyMsg{ - hasSignatureAlgorithm: c.vers >= VersionTLS12, - } - case typeFinished: - m = new(finishedMsg) - case typeEncryptedExtensions: - m = new(encryptedExtensionsMsg) - case typeEndOfEarlyData: - m = new(endOfEarlyDataMsg) - case typeKeyUpdate: - m = new(keyUpdateMsg) - default: - return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - - // The handshake message unmarshalers - // expect to be able to keep references to data, - // so pass in a fresh copy that won't be overwritten. - data = append([]byte(nil), data...) - - if !m.unmarshal(data) { - return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - - if transcript != nil { - transcript.Write(data) - } - - return m, nil -} - -var ( - errShutdown = errors.New("tls: protocol is shutdown") -) - -// Write writes data to the connection. -// -// As Write calls Handshake, in order to prevent indefinite blocking a deadline -// must be set for both Read and Write before Write is called when the handshake -// has not yet completed. See SetDeadline, SetReadDeadline, and -// SetWriteDeadline. -func (c *Conn) Write(b []byte) (int, error) { - // interlock with Close below - - if !c.HandshakeComplete() { - return 0, nil - } - - // c.buffering = false - - if err := c.out.err; err != nil { - return 0, err - } - - if c.closeNotifySent { - return 0, errShutdown - } - - // TLS 1.0 is susceptible to a chosen-plaintext - // attack when using block mode ciphers due to predictable IVs. - // This can be prevented by splitting each Application Data - // record into two records, effectively randomizing the IV. - // - // https://www.openssl.org/~bodo/tls-cbc.txt - // https://bugzilla.mozilla.org/show_bug.cgi?id=665814 - // https://www.imperialviolet.org/2012/01/15/beastfollowup.html - - var m int - if len(b) > 1 && c.vers == VersionTLS10 { - if _, ok := c.out.cipher.(cipher.BlockMode); ok { - n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1]) - if err != nil { - return n, c.out.setErrorLocked(err) - } - m, b = 1, b[1:] - } - } - - n, err := c.writeRecordLocked(recordTypeApplicationData, b) - return n + m, c.out.setErrorLocked(err) -} - -// check whether the data is a complete TLS record -func (c *Conn) IsRecordCompleted(data []byte) bool { - if len(data) < recordHeaderLen { - return false - } - if len(data) < recordHeaderLen+int(data[3])<<8|int(data[4]) { - return false - } - return true -} - -// load the data into the TLS rawInput -// If rawInput is lazy and empty, the data is loaded immediately -// as a reference (zero-copy) -func (c *Conn) RawInputSet(data []byte) (int, error) { - c.rawInput.Set(data) - return len(data), nil -} - -// Decrypt one tls record and save it in c.input which is -// owned by c.rawInput -func (c *Conn) ReadFrame() error { - return c.readRecordOrCCS(false) -} - -// return all rawInput data -func (c *Conn) RawInputData() []byte { - return c.rawInput.Bytes() -} - -// Clean up all decrypted data -// rawInput is cleaned up if all rawInput are processed. -// otherwise raw data is cached in order to make sure the -// is not owned by anyone else -func (c *Conn) DataDone() { - if c.rawInput.Len() == 0 { - // raw input is drain, thus clean it up - c.rawInput.Done() - } else { - // raw data has a few bytes left but not sufficient, so we cache it - // to make sure the data is not owned by anyone else - c.rawInput.Write(nil) - } - c.input.Reset(nil) - c.data = nil -} - -// call this function after close. so allocated memory for rawInput and hand -// are returned to the pool -func (c *Conn) DataCleanUpAfterClose() { - c.data = nil - c.input.Reset(nil) - c.rawInput.Done() -} - -// Return all decrypted data -func (c *Conn) Data() []byte { - return c.data -} - -// handleRenegotiation processes a HelloRequest handshake message. -func (c *Conn) handleRenegotiation() error { - if c.vers == VersionTLS13 { - return errors.New("tls: internal error: unexpected renegotiation") - } - - msg, err := c.readHandshake(nil) - if err != nil { - return err - } - - helloReq, ok := msg.(*helloRequestMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(helloReq, msg) - } - - if !c.isClient { - return c.sendAlert(alertNoRenegotiation) - } - - switch c.config.Renegotiation { - case RenegotiateNever: - return c.sendAlert(alertNoRenegotiation) - case RenegotiateOnceAsClient: - if c.handshakes > 1 { - return c.sendAlert(alertNoRenegotiation) - } - case RenegotiateFreelyAsClient: - // Ok. - default: - c.sendAlert(alertInternalError) - return errors.New("tls: unknown Renegotiation value") - } - - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - - c.handshakeStatus = 0 - if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil { - c.handshakes++ - } - return c.handshakeErr -} - -// handlePostHandshakeMessage processes a handshake message arrived after the -// handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation. -func (c *Conn) handlePostHandshakeMessage() error { - if c.vers != VersionTLS13 { - return c.handleRenegotiation() - } - - msg, err := c.readHandshake(nil) - if err != nil { - return err - } - - c.retryCount++ - if c.retryCount > maxUselessRecords { - c.sendAlert(alertUnexpectedMessage) - return c.in.setErrorLocked(errors.New("tls: too many non-advancing records")) - } - - switch msg := msg.(type) { - case *newSessionTicketMsgTLS13: - return c.handleNewSessionTicket(msg) - case *keyUpdateMsg: - return c.handleKeyUpdate(msg) - default: - c.sendAlert(alertUnexpectedMessage) - return fmt.Errorf("tls: received unexpected handshake message of type %T", msg) - } -} - -func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { - cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) - if cipherSuite == nil { - return c.in.setErrorLocked(c.sendAlert(alertInternalError)) - } - - newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret) - c.in.setTrafficSecret(cipherSuite, newSecret) - - if keyUpdate.updateRequested { - - msg := &keyUpdateMsg{} - msgBytes, err := msg.marshal() - if err != nil { - return err - } - _, err = c.writeRecordLocked(recordTypeHandshake, msgBytes) - if err != nil { - // Surface the error at the next write. - c.out.setErrorLocked(err) - return nil - } - - newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret) - c.out.setTrafficSecret(cipherSuite, newSecret) - } - - return nil -} - -// Read reads data from the connection. -// -// As Read calls Handshake, in order to prevent indefinite blocking a deadline -// must be set for both Read and Write before Read is called when the handshake -// has not yet completed. See SetDeadline, SetReadDeadline, and -// SetWriteDeadline. -func (c *Conn) Read(b []byte) (int, error) { - if !c.HandshakeComplete() { - return 0, nil - } - if len(b) == 0 { - // Put this after Handshake, in case people were calling - // Read(nil) for the side effect of the Handshake. - return 0, nil - } - - for c.input.Len() == 0 { - if err := c.readRecord(); err != nil { - return 0, err - } - for c.hand.Len() > 0 { - if err := c.handlePostHandshakeMessage(); err != nil { - return 0, err - } - } - } - - n, _ := c.input.Read(b) - - // If a close-notify alert is waiting, read it so that we can return (n, - // EOF) instead of (n, nil), to signal to the HTTP response reading - // goroutine that the connection is now closed. This eliminates a race - // where the HTTP response reading goroutine would otherwise not observe - // the EOF until its next read, by which time a client goroutine might - // have already tried to reuse the HTTP connection for a new request. - // See https://golang.org/cl/76400046 and https://golang.org/issue/3514 - if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 && - recordType(c.rawInput.Bytes()[0]) == recordTypeAlert { - if err := c.readRecord(); err != nil { - return n, err // will be io.EOF on closeNotify - } - } - - return n, nil -} - -// Close closes the connection. -func (c *Conn) Close() error { - var alertErr error - if c.HandshakeComplete() { - if err := c.closeNotify(); err != nil { - alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err) - } - } - return alertErr -} - -var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete") - -// CloseWrite shuts down the writing side of the connection. It should only be -// called once the handshake has completed and does not call CloseWrite on the -// underlying connection. Most callers should just use Close. -func (c *Conn) CloseWrite() error { - if !c.HandshakeComplete() { - return errEarlyCloseWrite - } - - return c.closeNotify() -} - -func (c *Conn) closeNotify() error { - if !c.closeNotifySent { - c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify) - c.closeNotifySent = true - } - return c.closeNotifyErr -} - -// Handshake runs the client or server handshake -// protocol if it has not yet been run. -// -// Most uses of this package need not call Handshake explicitly: the -// first Read or Write will call it automatically. -// -// For control over canceling or setting a timeout on a handshake, use -// HandshakeContext or the Dialer's DialContext method instead. -func (c *Conn) Handshake() error { - return c.HandshakeContext(context.Background()) -} - -// HandshakeContext runs the client or server handshake -// protocol if it has not yet been run. -// -// The provided Context must be non-nil. If the context is canceled before -// the handshake is complete, the handshake is interrupted and an error is returned. -// Once the handshake has completed, cancellation of the context will not affect the -// connection. -// -// Most uses of this package need not call HandshakeContext explicitly: the -// first Read or Write will call it automatically. -func (c *Conn) HandshakeContext(ctx context.Context) error { - // Delegate to unexported method for named return - // without confusing documented signature. - return c.handshakeContext(ctx) -} - -func (c *Conn) handshakeContext(ctx context.Context) (ret error) { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - - if err := c.handshakeErr; err != nil { - return err - } - if c.HandshakeComplete() { - return nil - } - - c.handshakeErr = c.handshakeFn(ctx) - if c.handshakeErr == io.EOF { - c.handshakeErr = nil - } - if c.handshakeErr == nil { - c.handshakes++ - } else { - // If an error occurred during the handshake try to flush the - // alert that might be left in the buffer. - c.flush() - } - - return c.handshakeErr -} - -// ConnectionState returns basic TLS details about the connection. -func (c *Conn) ConnectionState() ConnectionState { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - return c.connectionStateLocked() -} - -func (c *Conn) connectionStateLocked() ConnectionState { - var state ConnectionState - state.HandshakeComplete = c.HandshakeComplete() - state.Version = c.vers - state.NegotiatedProtocol = c.clientProtocol - state.DidResume = c.didResume - state.NegotiatedProtocolIsMutual = true - state.ServerName = c.serverName - state.CipherSuite = c.cipherSuite - state.PeerCertificates = c.peerCertificates - state.VerifiedChains = c.verifiedChains - state.SignedCertificateTimestamps = c.scts - state.OCSPResponse = c.ocspResponse - if !c.didResume && c.vers != VersionTLS13 { - if c.clientFinishedIsFirst { - state.TLSUnique = c.clientFinished[:] - } else { - state.TLSUnique = c.serverFinished[:] - } - } - if c.config.Renegotiation != RenegotiateNever { - state.ekm = noExportedKeyingMaterial - } else { - state.ekm = c.ekm - } - return state -} - -func (c *Conn) HandshakeComplete() bool { - return c.handshakeStatus == 255 -} - -// OCSPResponse returns the stapled OCSP response from the TLS server, if -// any. (Only valid for client connections.) -func (c *Conn) OCSPResponse() []byte { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - - return c.ocspResponse -} - -// VerifyHostname checks that the peer certificate chain is valid for -// connecting to host. If so, it returns nil; if not, it returns an error -// describing the problem. -func (c *Conn) VerifyHostname(host string) error { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - if !c.isClient { - return errors.New("tls: VerifyHostname called on TLS server connection") - } - if !c.HandshakeComplete() { - return errors.New("tls: handshake has not yet been performed") - } - if len(c.verifiedChains) == 0 { - return errors.New("tls: handshake did not verify certificate chain") - } - return c.peerCertificates[0].VerifyHostname(host) -} diff --git a/pkg/tls/generate_cert.go b/pkg/tls/generate_cert.go deleted file mode 100644 index cd4bfc513..000000000 --- a/pkg/tls/generate_cert.go +++ /dev/null @@ -1,171 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build ignore - -// Generate a self-signed X.509 certificate for a TLS server. Outputs to -// 'cert.pem' and 'key.pem' and will overwrite existing files. - -package main - -import ( - "crypto/ecdsa" - "crypto/ed25519" - "crypto/elliptic" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "flag" - "log" - "math/big" - "net" - "os" - "strings" - "time" -) - -var ( - host = flag.String("host", "", "Comma-separated hostnames and IPs to generate a certificate for") - validFrom = flag.String("start-date", "", "Creation date formatted as Jan 1 15:04:05 2011") - validFor = flag.Duration("duration", 365*24*time.Hour, "Duration that certificate is valid for") - isCA = flag.Bool("ca", false, "whether this cert should be its own Certificate Authority") - rsaBits = flag.Int("rsa-bits", 2048, "Size of RSA key to generate. Ignored if --ecdsa-curve is set") - ecdsaCurve = flag.String("ecdsa-curve", "", "ECDSA curve to use to generate a key. Valid values are P224, P256 (recommended), P384, P521") - ed25519Key = flag.Bool("ed25519", false, "Generate an Ed25519 key") -) - -func publicKey(priv any) any { - switch k := priv.(type) { - case *rsa.PrivateKey: - return &k.PublicKey - case *ecdsa.PrivateKey: - return &k.PublicKey - case ed25519.PrivateKey: - return k.Public().(ed25519.PublicKey) - default: - return nil - } -} - -func main() { - flag.Parse() - - if len(*host) == 0 { - log.Fatalf("Missing required --host parameter") - } - - var priv any - var err error - switch *ecdsaCurve { - case "": - if *ed25519Key { - _, priv, err = ed25519.GenerateKey(rand.Reader) - } else { - priv, err = rsa.GenerateKey(rand.Reader, *rsaBits) - } - case "P224": - priv, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader) - case "P256": - priv, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - case "P384": - priv, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader) - case "P521": - priv, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader) - default: - log.Fatalf("Unrecognized elliptic curve: %q", *ecdsaCurve) - } - if err != nil { - log.Fatalf("Failed to generate private key: %v", err) - } - - // ECDSA, ED25519 and RSA subject keys should have the DigitalSignature - // KeyUsage bits set in the x509.Certificate template - keyUsage := x509.KeyUsageDigitalSignature - // Only RSA subject keys should have the KeyEncipherment KeyUsage bits set. In - // the context of TLS this KeyUsage is particular to RSA key exchange and - // authentication. - if _, isRSA := priv.(*rsa.PrivateKey); isRSA { - keyUsage |= x509.KeyUsageKeyEncipherment - } - - var notBefore time.Time - if len(*validFrom) == 0 { - notBefore = time.Now() - } else { - notBefore, err = time.Parse("Jan 2 15:04:05 2006", *validFrom) - if err != nil { - log.Fatalf("Failed to parse creation date: %v", err) - } - } - - notAfter := notBefore.Add(*validFor) - - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - if err != nil { - log.Fatalf("Failed to generate serial number: %v", err) - } - - template := x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{ - Organization: []string{"Acme Co"}, - }, - NotBefore: notBefore, - NotAfter: notAfter, - - KeyUsage: keyUsage, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - } - - hosts := strings.Split(*host, ",") - for _, h := range hosts { - if ip := net.ParseIP(h); ip != nil { - template.IPAddresses = append(template.IPAddresses, ip) - } else { - template.DNSNames = append(template.DNSNames, h) - } - } - - if *isCA { - template.IsCA = true - template.KeyUsage |= x509.KeyUsageCertSign - } - - derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv) - if err != nil { - log.Fatalf("Failed to create certificate: %v", err) - } - - certOut, err := os.Create("cert.pem") - if err != nil { - log.Fatalf("Failed to open cert.pem for writing: %v", err) - } - if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { - log.Fatalf("Failed to write data to cert.pem: %v", err) - } - if err := certOut.Close(); err != nil { - log.Fatalf("Error closing cert.pem: %v", err) - } - log.Print("wrote cert.pem\n") - - keyOut, err := os.OpenFile("key.pem", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - log.Fatalf("Failed to open key.pem for writing: %v", err) - } - privBytes, err := x509.MarshalPKCS8PrivateKey(priv) - if err != nil { - log.Fatalf("Unable to marshal private key: %v", err) - } - if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { - log.Fatalf("Failed to write data to key.pem: %v", err) - } - if err := keyOut.Close(); err != nil { - log.Fatalf("Error closing key.pem: %v", err) - } - log.Print("wrote key.pem\n") -} diff --git a/pkg/tls/go120.go b/pkg/tls/go120.go new file mode 100644 index 000000000..e01e88aa3 --- /dev/null +++ b/pkg/tls/go120.go @@ -0,0 +1,87 @@ +//go:build go1.20 + +package tls + +import ( + "net" + _ "unsafe" + + gtls "github.com/0-haha/gnet_go_tls/v120" +) + +const ( + // TLS 1.0 - 1.2 cipher suites. + TLS_RSA_WITH_RC4_128_SHA uint16 = gtls.TLS_RSA_WITH_RC4_128_SHA + TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = gtls.TLS_RSA_WITH_3DES_EDE_CBC_SHA + TLS_RSA_WITH_AES_128_CBC_SHA uint16 = gtls.TLS_RSA_WITH_AES_128_CBC_SHA + TLS_RSA_WITH_AES_256_CBC_SHA uint16 = gtls.TLS_RSA_WITH_AES_256_CBC_SHA + TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = gtls.TLS_RSA_WITH_AES_128_CBC_SHA256 + TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = gtls.TLS_RSA_WITH_AES_128_GCM_SHA256 + TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = gtls.TLS_RSA_WITH_AES_256_GCM_SHA384 + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = gtls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = gtls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = gtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA + TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = gtls.TLS_ECDHE_RSA_WITH_RC4_128_SHA + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = gtls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = gtls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = gtls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = gtls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = gtls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = gtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = gtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = gtls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = gtls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = gtls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = gtls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 + + // TLS 1.3 cipher suites. + TLS_AES_128_GCM_SHA256 uint16 = gtls.TLS_AES_128_GCM_SHA256 + TLS_AES_256_GCM_SHA384 uint16 = gtls.TLS_AES_256_GCM_SHA384 + TLS_CHACHA20_POLY1305_SHA256 uint16 = gtls.TLS_CHACHA20_POLY1305_SHA256 + + // TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator + // that the client is doing version fallback. See RFC 7507. + TLS_FALLBACK_SCSV uint16 = gtls.TLS_FALLBACK_SCSV + + // Legacy names for the corresponding cipher suites with the correct _SHA256 + // suffix, retained for backward compatibility. + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 = gtls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 = gtls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 +) + +type ( + // A Certificate is gnet_go_tls101.Certificate. + Certificate = gtls.Certificate + // CertificateRequestInfo contains information about a certificate request. + CertificateRequestInfo = gtls.CertificateRequestInfo + // ClientHelloInfo contains information about a ClientHello. + ClientHelloInfo = gtls.ClientHelloInfo + // ClientSessionCache is a cache used for session resumption. + ClientSessionCache = gtls.ClientSessionCache + // ClientSessionState is a state needed for session resumption. + ClientSessionState = gtls.ClientSessionState + // A Config is a gnet_go_tls101.Config. + Config = gtls.Config + // A Conn is a gnet_go_tls101.Conn. + Conn = gtls.Conn +) + +const ( + VersionTLS10 = gtls.VersionTLS10 + VersionTLS11 = gtls.VersionTLS11 + VersionTLS12 = gtls.VersionTLS12 + VersionTLS13 = gtls.VersionTLS13 + + // Deprecated: SSLv3 is cryptographically broken, and is no longer + // supported by this package. See golang.org/issue/32716. + VersionSSL30 = gtls.CurveP256 +) + +//go:linkname Server github.com/0-haha/gnet_go_tls/v120.Server +func Server(conn net.Conn, config *Config) *Conn + +//go:linkname LoadX509KeyPair github.com/0-haha/gnet_go_tls/v120.LoadX509KeyPair +func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) + +//go:linkname X509KeyPair github.com/0-haha/gnet_go_tls/v120.X509KeyPair +func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) \ No newline at end of file diff --git a/pkg/tls/go_oldversion.go b/pkg/tls/go_oldversion.go new file mode 100644 index 000000000..02128fb4d --- /dev/null +++ b/pkg/tls/go_oldversion.go @@ -0,0 +1,5 @@ +//go:build !go1.20 + +package tls + +var _ int = "The version of gnet you're using can't be built using outdated Go versions. For more details, please see [TODO: add links here]." \ No newline at end of file diff --git a/pkg/tls/handshake_client.go b/pkg/tls/handshake_client.go deleted file mode 100644 index 8d8633eb1..000000000 --- a/pkg/tls/handshake_client.go +++ /dev/null @@ -1,1074 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "bytes" - "context" - "crypto" - "crypto/ecdh" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/rsa" - "crypto/subtle" - "crypto/x509" - "errors" - "fmt" - "hash" - "io" - "net" - "strconv" - "strings" - "time" -) - -type clientHandshakeState struct { - c *Conn - ctx context.Context - serverHello *serverHelloMsg - hello *clientHelloMsg - suite *cipherSuite - finishedHash finishedHash - masterSecret []byte - session *ClientSessionState - oldSession *ClientSessionState - cacheKey string -} - -var testingOnlyForceClientHelloSignatureAlgorithms []SignatureScheme - -func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { - config := c.config - if len(config.ServerName) == 0 && !config.InsecureSkipVerify { - return nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") - } - - nextProtosLength := 0 - for _, proto := range config.NextProtos { - if l := len(proto); l == 0 || l > 255 { - return nil, nil, errors.New("tls: invalid NextProtos value") - } else { - nextProtosLength += 1 + l - } - } - if nextProtosLength > 0xffff { - return nil, nil, errors.New("tls: NextProtos values too large") - } - - supportedVersions := config.supportedVersions(roleClient) - if len(supportedVersions) == 0 { - return nil, nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion") - } - - clientHelloVersion := config.maxSupportedVersion(roleClient) - // The version at the beginning of the ClientHello was capped at TLS 1.2 - // for compatibility reasons. The supported_versions extension is used - // to negotiate versions now. See RFC 8446, Section 4.2.1. - if clientHelloVersion > VersionTLS12 { - clientHelloVersion = VersionTLS12 - } - - hello := &clientHelloMsg{ - vers: clientHelloVersion, - compressionMethods: []uint8{compressionNone}, - random: make([]byte, 32), - sessionId: make([]byte, 32), - ocspStapling: true, - scts: true, - serverName: hostnameInSNI(config.ServerName), - supportedCurves: config.curvePreferences(), - supportedPoints: []uint8{pointFormatUncompressed}, - secureRenegotiationSupported: true, - alpnProtocols: config.NextProtos, - supportedVersions: supportedVersions, - } - - if c.handshakes > 0 { - hello.secureRenegotiation = c.clientFinished[:] - } - - preferenceOrder := cipherSuitesPreferenceOrder - if !hasAESGCMHardwareSupport { - preferenceOrder = cipherSuitesPreferenceOrderNoAES - } - configCipherSuites := config.cipherSuites() - hello.cipherSuites = make([]uint16, 0, len(configCipherSuites)) - - for _, suiteId := range preferenceOrder { - suite := mutualCipherSuite(configCipherSuites, suiteId) - if suite == nil { - continue - } - // Don't advertise TLS 1.2-only cipher suites unless - // we're attempting TLS 1.2. - if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 { - continue - } - hello.cipherSuites = append(hello.cipherSuites, suiteId) - } - - _, err := io.ReadFull(config.rand(), hello.random) - if err != nil { - return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) - } - - // A random session ID is used to detect when the server accepted a ticket - // and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as - // a compatibility measure (see RFC 8446, Section 4.1.2). - if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil { - return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) - } - - if hello.vers >= VersionTLS12 { - hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms() - } - if testingOnlyForceClientHelloSignatureAlgorithms != nil { - hello.supportedSignatureAlgorithms = testingOnlyForceClientHelloSignatureAlgorithms - } - - var key *ecdh.PrivateKey - if hello.supportedVersions[0] == VersionTLS13 { - if hasAESGCMHardwareSupport { - hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13...) - } else { - hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...) - } - - curveID := config.curvePreferences()[0] - if _, ok := curveForCurveID(curveID); !ok { - return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") - } - key, err = generateECDHEKey(config.rand(), curveID) - if err != nil { - return nil, nil, err - } - hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} - } - - return hello, key, nil -} - -func (c *Conn) clientHandshake(ctx context.Context) (err error) { - - switch c.handshakeStatus { - case 0: - if c.config == nil { - c.config = defaultConfig() - } - // This may be a renegotiation handshake, in which case some fields - // need to be reset. - c.didResume = false - - hello, ecdheKey, err := c.makeClientHello() - if err != nil { - return err - } - c.serverName = hello.serverName - - cacheKey, session, earlySecret, binderKey, err := c.loadSession(hello) - if err != nil { - return err - } - if cacheKey != "" && session != nil { - defer func() { - // If we got a handshake failure when resuming a session, throw away - // the session ticket. See RFC 5077, Section 3.2. - // - // RFC 8446 makes no mention of dropping tickets on failure, but it - // does require servers to abort on invalid binders, so we need to - // delete tickets to recover from a corrupted PSK. - if err != nil { - c.config.ClientSessionCache.Put(cacheKey, nil) - } - }() - } - - if _, err := c.writeHandshakeRecord(hello, nil); err != nil { - return err - } - c.flush() - c.handshakeStatus = 1 //已发送hello,等待下一个数据包 - - c.hs = &clientHandshakeStateTLS13{ //临时缓存 - c: c, - ctx: ctx, - hello: hello, - ecdheKey: ecdheKey, - session: session, - earlySecret: earlySecret, - binderKey: binderKey, - cacheKey: cacheKey, - } - - return nil - case 1: - hello := c.hs.(*clientHandshakeStateTLS13).hello - // serverHelloMsg is not included in the transcript - msg, err := c.readHandshake(nil) - if err != nil { - return err - } - - serverHello, ok := msg.(*serverHelloMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(serverHello, msg) - } - - if err := c.pickTLSVersion(serverHello); err != nil { - return err - } - c.handshakeStatus = 2 - // If we are negotiating a protocol version that's lower than what we - // support, check for the server downgrade canaries. - // See RFC 8446, Section 4.1.3. - maxVers := c.config.maxSupportedVersion(roleClient) - tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12 - tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11 - if maxVers == VersionTLS13 && c.vers <= VersionTLS12 && (tls12Downgrade || tls11Downgrade) || - maxVers == VersionTLS12 && c.vers <= VersionTLS11 && tls11Downgrade { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox") - } - hs13 := c.hs.(*clientHandshakeStateTLS13) - if c.vers == VersionTLS13 { - hs13.serverHello = serverHello - // In TLS 1.3, session tickets are delivered after the handshake. - return c.hs.handshake() - } - hs := &clientHandshakeState{ - c: c, - ctx: ctx, - serverHello: serverHello, - hello: hello, - session: hs13.session, - oldSession: hs13.session, - cacheKey: hs13.cacheKey, - } - c.hs = hs - if err := hs.handshake(); err != nil { - return err - } - case 3, 4, 5: - c.hs.handshake() - - default: - return errors.New("tls handshakeStatus error:" + strconv.Itoa(int(c.handshakeStatus))) - } - - return nil -} - -func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, - session *ClientSessionState, earlySecret, binderKey []byte, err error) { - if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { - return "", nil, nil, nil, nil - } - - hello.ticketSupported = true - - if hello.supportedVersions[0] == VersionTLS13 { - // Require DHE on resumption as it guarantees forward secrecy against - // compromise of the session ticket key. See RFC 8446, Section 4.2.9. - hello.pskModes = []uint8{pskModeDHE} - } - - // Session resumption is not allowed if renegotiating because - // renegotiation is primarily used to allow a client to send a client - // certificate, which would be skipped if session resumption occurred. - if c.handshakes != 0 { - return "", nil, nil, nil, nil - } - - // Try to resume a previously negotiated TLS session, if available. - cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) - session, ok := c.config.ClientSessionCache.Get(cacheKey) - if !ok || session == nil { - return cacheKey, nil, nil, nil, nil - } - - // Check that version used for the previous session is still valid. - versOk := false - for _, v := range hello.supportedVersions { - if v == session.vers { - versOk = true - break - } - } - if !versOk { - return cacheKey, nil, nil, nil, nil - } - - // Check that the cached server certificate is not expired, and that it's - // valid for the ServerName. This should be ensured by the cache key, but - // protect the application from a faulty ClientSessionCache implementation. - if !c.config.InsecureSkipVerify { - if len(session.verifiedChains) == 0 { - // The original connection had InsecureSkipVerify, while this doesn't. - return cacheKey, nil, nil, nil, nil - } - serverCert := session.serverCertificates[0] - if c.config.time().After(serverCert.NotAfter) { - // Expired certificate, delete the entry. - c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil, nil - } - if err := serverCert.VerifyHostname(c.config.ServerName); err != nil { - return cacheKey, nil, nil, nil, nil - } - } - - if session.vers != VersionTLS13 { - // In TLS 1.2 the cipher suite must match the resumed session. Ensure we - // are still offering it. - if mutualCipherSuite(hello.cipherSuites, session.cipherSuite) == nil { - return cacheKey, nil, nil, nil, nil - } - - hello.sessionTicket = session.sessionTicket - return - } - - // Check that the session ticket is not expired. - if c.config.time().After(session.useBy) { - c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil, nil - } - - // In TLS 1.3 the KDF hash must match the resumed session. Ensure we - // offer at least one cipher suite with that hash. - cipherSuite := cipherSuiteTLS13ByID(session.cipherSuite) - if cipherSuite == nil { - return cacheKey, nil, nil, nil, nil - } - cipherSuiteOk := false - for _, offeredID := range hello.cipherSuites { - offeredSuite := cipherSuiteTLS13ByID(offeredID) - if offeredSuite != nil && offeredSuite.hash == cipherSuite.hash { - cipherSuiteOk = true - break - } - } - if !cipherSuiteOk { - return cacheKey, nil, nil, nil, nil - } - - // Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1. - ticketAge := uint32(c.config.time().Sub(session.receivedAt) / time.Millisecond) - identity := pskIdentity{ - label: session.sessionTicket, - obfuscatedTicketAge: ticketAge + session.ageAdd, - } - hello.pskIdentities = []pskIdentity{identity} - hello.pskBinders = [][]byte{make([]byte, cipherSuite.hash.Size())} - - // Compute the PSK binders. See RFC 8446, Section 4.2.11.2. - psk := cipherSuite.expandLabel(session.masterSecret, "resumption", - session.nonce, cipherSuite.hash.Size()) - earlySecret = cipherSuite.extract(psk, nil) - binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil) - transcript := cipherSuite.hash.New() - helloBytes, err := hello.marshalWithoutBinders() - if err != nil { - return "", nil, nil, nil, err - } - transcript.Write(helloBytes) - pskBinders := [][]byte{cipherSuite.finishedHash(binderKey, transcript)} - if err := hello.updateBinders(pskBinders); err != nil { - return "", nil, nil, nil, err - } - - return -} - -func (c *Conn) pickTLSVersion(serverHello *serverHelloMsg) error { - peerVersion := serverHello.vers - if serverHello.supportedVersion != 0 { - peerVersion = serverHello.supportedVersion - } - - vers, ok := c.config.mutualVersion(roleClient, []uint16{peerVersion}) - if !ok { - c.sendAlert(alertProtocolVersion) - return fmt.Errorf("tls: server selected unsupported protocol version %x", peerVersion) - } - - c.vers = vers - c.haveVers = true - c.in.version = vers - c.out.version = vers - - return nil -} - -// Does the handshake, either a full one or resumes old session. Requires hs.c, -// hs.hello, hs.serverHello, and, optionally, hs.session to be set. -func (hs *clientHandshakeState) handshake() error { - c := hs.c - - if c.handshakeStatus == 2 { - isResume, err := hs.processServerHello() - c.didResume = isResume - - if err != nil { - return err - } - hs.finishedHash = newFinishedHash(c.vers, hs.suite) - - // No signatures of the handshake are needed in a resumption. - // Otherwise, in a full handshake, if we don't have any certificates - // configured then we will never send a CertificateVerify message and - // thus no signatures are needed in that case either. - if c.didResume || (len(c.config.Certificates) == 0 && c.config.GetClientCertificate == nil) { - hs.finishedHash.discardHandshakeBuffer() - } - - if err := transcriptMsg(hs.hello, &hs.finishedHash); err != nil { - return err - } - if err := transcriptMsg(hs.serverHello, &hs.finishedHash); err != nil { - return err - } - c.handshakeStatus = 3 - c.buffering = true - } - - if c.didResume { - if err := hs.establishKeys(); err != nil { - return err - } - if err := hs.readSessionTicket(); err != nil { - return err - } - if err := hs.readFinished(c.serverFinished[:]); err != nil { - return err - } - c.clientFinishedIsFirst = false - // Make sure the connection is still being verified whether or not this - // is a resumption. Resumptions currently don't reverify certificates so - // they don't call verifyServerCertificate. See Issue 31641. - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - if err := hs.sendFinished(c.clientFinished[:]); err != nil { - return err - } - if _, err := c.flush(); err != nil { - return err - } - } else { - switch c.handshakeStatus { - case 3: - if err := hs.doFullHandshakeStep1(); err != nil { - return err - } - c.handshakeStatus = 4 - return nil - case 4: - if err := hs.doFullHandshakeStep2(); err != nil { - return err - } - if err := hs.establishKeys(); err != nil { - return err - } - if err := hs.sendFinished(c.clientFinished[:]); err != nil { - return err - } - _, err := c.flush() - - c.handshakeStatus = 5 - return err - case 5: - c.clientFinishedIsFirst = true - if err := hs.readSessionTicket(); err != nil { - return err - } - if err := hs.readFinished(c.serverFinished[:]); err != nil { - return err - } - } - } - - c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random) - - c.handshakeStatus = 255 - // If we had a successful handshake and hs.session is different from - // the one already cached - cache a new one. - - if hs.cacheKey != "" && hs.session != nil && hs.oldSession != hs.session { - c.config.ClientSessionCache.Put(hs.cacheKey, hs.session) - } - // Enable kernel TLS if possible - if err := c.enableKernelTLS(c.cipherSuite, c.in.key, c.out.key, c.in.iv, c.out.iv); err != nil { - return err - } - - return nil -} - -func (hs *clientHandshakeState) pickCipherSuite() error { - if hs.suite = mutualCipherSuite(hs.hello.cipherSuites, hs.serverHello.cipherSuite); hs.suite == nil { - hs.c.sendAlert(alertHandshakeFailure) - return errors.New("tls: server chose an unconfigured cipher suite") - } - - hs.c.cipherSuite = hs.suite.id - return nil -} - -func (hs *clientHandshakeState) doFullHandshakeStep1() error { - c := hs.c - - msg, err := c.readHandshake(&hs.finishedHash) - if err != nil { - return err - } - certMsg, ok := msg.(*certificateMsg) - if !ok || len(certMsg.certificates) == 0 { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certMsg, msg) - } - - msg, err = c.readHandshake(&hs.finishedHash) - if c.handshakes == 1 || len(c.peerCertificates) == 0 { - // If this is the first handshake on a connection, process and - // (optionally) verify the server's certificates. - if err := c.verifyServerCertificate(certMsg.certificates); err != nil { - return err - } - } else { - // This is a renegotiation handshake. We require that the - // server's identity (i.e. leaf certificate) is unchanged and - // thus any previous trust decision is still valid. - // - // See https://mitls.org/pages/attacks/3SHAKE for the - // motivation behind this requirement. - if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) { - c.sendAlert(alertBadCertificate) - return errors.New("tls: server's identity changed during renegotiation") - } - } - return nil -} -func (hs *clientHandshakeState) doFullHandshakeStep2() error { - c := hs.c - msg, err := c.readHandshake(&hs.finishedHash) - if err != nil { - return err - } - - cs, ok := msg.(*certificateStatusMsg) - if ok { - // RFC4366 on Certificate Status Request: - // The server MAY return a "certificate_status" message. - - if !hs.serverHello.ocspStapling { - // If a server returns a "CertificateStatus" message, then the - // server MUST have included an extension of type "status_request" - // with empty "extension_data" in the extended server hello. - - c.sendAlert(alertUnexpectedMessage) - return errors.New("tls: received unexpected CertificateStatus message") - } - - c.ocspResponse = cs.response - - msg, err = c.readHandshake(&hs.finishedHash) - if err != nil { - return err - } - } - - keyAgreement := hs.suite.ka(c.vers) - - skx, ok := msg.(*serverKeyExchangeMsg) - if ok { - err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx) - if err != nil { - c.sendAlert(alertUnexpectedMessage) - return err - } - - msg, err = c.readHandshake(&hs.finishedHash) - if err != nil { - return err - } - } - - var chainToSend *Certificate - var certRequested bool - certReq, ok := msg.(*certificateRequestMsg) - if ok { - certRequested = true - - cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq) - if chainToSend, err = c.getClientCertificate(cri); err != nil { - c.sendAlert(alertInternalError) - return err - } - - msg, err = c.readHandshake(&hs.finishedHash) - if err != nil { - return err - } - } - - shd, ok := msg.(*serverHelloDoneMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(shd, msg) - } - - // If the server requested a certificate then we have to send a - // Certificate message, even if it's empty because we don't have a - // certificate to send. - if certRequested { - certMsg := new(certificateMsg) - certMsg.certificates = chainToSend.Certificate - if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil { - return err - } - } - - preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, c.peerCertificates[0]) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - if ckx != nil { - if _, err := hs.c.writeHandshakeRecord(ckx, &hs.finishedHash); err != nil { - return err - } - } - - if chainToSend != nil && len(chainToSend.Certificate) > 0 { - certVerify := &certificateVerifyMsg{} - - key, ok := chainToSend.PrivateKey.(crypto.Signer) - if !ok { - c.sendAlert(alertInternalError) - return fmt.Errorf("tls: client certificate private key of type %T does not implement crypto.Signer", chainToSend.PrivateKey) - } - - var sigType uint8 - var sigHash crypto.Hash - if c.vers >= VersionTLS12 { - signatureAlgorithm, err := selectSignatureScheme(c.vers, chainToSend, certReq.supportedSignatureAlgorithms) - if err != nil { - c.sendAlert(alertIllegalParameter) - return err - } - sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) - if err != nil { - return c.sendAlert(alertInternalError) - } - certVerify.hasSignatureAlgorithm = true - certVerify.signatureAlgorithm = signatureAlgorithm - } else { - sigType, sigHash, err = legacyTypeAndHashFromPublicKey(key.Public()) - if err != nil { - c.sendAlert(alertIllegalParameter) - return err - } - } - - signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash) - signOpts := crypto.SignerOpts(sigHash) - if sigType == signatureRSAPSS { - signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} - } - certVerify.signature, err = key.Sign(c.config.rand(), signed, signOpts) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - if _, err := hs.c.writeHandshakeRecord(certVerify, &hs.finishedHash); err != nil { - return err - } - } - - hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random) - if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.hello.random, hs.masterSecret); err != nil { - c.sendAlert(alertInternalError) - return errors.New("tls: failed to write to key log: " + err.Error()) - } - - hs.finishedHash.discardHandshakeBuffer() - - return nil -} - -func (hs *clientHandshakeState) establishKeys() error { - c := hs.c - - clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := - keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) - var clientCipher, serverCipher interface{} - var clientHash, serverHash hash.Hash - if hs.suite.cipher != nil { - clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */) - clientHash = hs.suite.mac(clientMAC) - serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */) - serverHash = hs.suite.mac(serverMAC) - } else { - clientCipher = hs.suite.aead(clientKey, clientIV) - serverCipher = hs.suite.aead(serverKey, serverIV) - } - - c.in.key, c.in.iv = serverKey, serverIV - c.out.key, c.out.iv = clientKey, clientIV - c.in.prepareCipherSpec(c.vers, serverCipher, serverHash) - c.out.prepareCipherSpec(c.vers, clientCipher, clientHash) - return nil -} - -func (hs *clientHandshakeState) serverResumedSession() bool { - // If the server responded with the same sessionId then it means the - // sessionTicket is being used to resume a TLS session. - return hs.session != nil && hs.hello.sessionId != nil && - bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId) -} - -func (hs *clientHandshakeState) processServerHello() (bool, error) { - c := hs.c - - if err := hs.pickCipherSuite(); err != nil { - return false, err - } - - if hs.serverHello.compressionMethod != compressionNone { - c.sendAlert(alertUnexpectedMessage) - return false, errors.New("tls: server selected unsupported compression format") - } - - if c.handshakes == 0 && hs.serverHello.secureRenegotiationSupported { - c.secureRenegotiation = true - if len(hs.serverHello.secureRenegotiation) != 0 { - c.sendAlert(alertHandshakeFailure) - return false, errors.New("tls: initial handshake had non-empty renegotiation extension") - } - } - - if c.handshakes > 0 && c.secureRenegotiation { - var expectedSecureRenegotiation [24]byte - copy(expectedSecureRenegotiation[:], c.clientFinished[:]) - copy(expectedSecureRenegotiation[12:], c.serverFinished[:]) - if !bytes.Equal(hs.serverHello.secureRenegotiation, expectedSecureRenegotiation[:]) { - c.sendAlert(alertHandshakeFailure) - return false, errors.New("tls: incorrect renegotiation extension contents") - } - } - - if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol); err != nil { - c.sendAlert(alertUnsupportedExtension) - return false, err - } - c.clientProtocol = hs.serverHello.alpnProtocol - - c.scts = hs.serverHello.scts - - if !hs.serverResumedSession() { - return false, nil - } - - if hs.session.vers != c.vers { - c.sendAlert(alertHandshakeFailure) - return false, errors.New("tls: server resumed a session with a different version") - } - - if hs.session.cipherSuite != hs.suite.id { - c.sendAlert(alertHandshakeFailure) - return false, errors.New("tls: server resumed a session with a different cipher suite") - } - - // Restore masterSecret, peerCerts, and ocspResponse from previous state - hs.masterSecret = hs.session.masterSecret - c.peerCertificates = hs.session.serverCertificates - c.verifiedChains = hs.session.verifiedChains - c.ocspResponse = hs.session.ocspResponse - // Let the ServerHello SCTs override the session SCTs from the original - // connection, if any are provided - if len(c.scts) == 0 && len(hs.session.scts) != 0 { - c.scts = hs.session.scts - } - - return true, nil -} - -// checkALPN ensure that the server's choice of ALPN protocol is compatible with -// the protocols that we advertised in the Client Hello. -func checkALPN(clientProtos []string, serverProto string) error { - if serverProto == "" { - return nil - } - if len(clientProtos) == 0 { - return errors.New("tls: server advertised unrequested ALPN extension") - } - for _, proto := range clientProtos { - if proto == serverProto { - return nil - } - } - return errors.New("tls: server selected unadvertised ALPN protocol") -} - -func (hs *clientHandshakeState) readFinished(out []byte) error { - c := hs.c - - if err := c.readChangeCipherSpec(); err != nil { - return err - } - - // finishedMsg is included in the transcript, but not until after we - // check the client version, since the state before this message was - // sent is used during verification. - msg, err := c.readHandshake(nil) - if err != nil { - return err - } - serverFinished, ok := msg.(*finishedMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(serverFinished, msg) - } - - verify := hs.finishedHash.serverSum(hs.masterSecret) - if len(verify) != len(serverFinished.verifyData) || - subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: server's Finished message was incorrect") - } - - if err := transcriptMsg(serverFinished, &hs.finishedHash); err != nil { - return err - } - - copy(out, verify) - return nil -} - -func (hs *clientHandshakeState) readSessionTicket() error { - if !hs.serverHello.ticketSupported { - return nil - } - - c := hs.c - msg, err := c.readHandshake(&hs.finishedHash) - if err != nil { - return err - } - sessionTicketMsg, ok := msg.(*newSessionTicketMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(sessionTicketMsg, msg) - } - - hs.session = &ClientSessionState{ - sessionTicket: sessionTicketMsg.ticket, - vers: c.vers, - cipherSuite: hs.suite.id, - masterSecret: hs.masterSecret, - serverCertificates: c.peerCertificates, - verifiedChains: c.verifiedChains, - receivedAt: c.config.time(), - ocspResponse: c.ocspResponse, - scts: c.scts, - } - - return nil -} - -func (hs *clientHandshakeState) sendFinished(out []byte) error { - c := hs.c - - if err := c.writeChangeCipherRecord(); err != nil { - return err - } - - finished := new(finishedMsg) - finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) - if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil { - return err - } - copy(out, finished.verifyData) - return nil -} - -// verifyServerCertificate parses and verifies the provided chain, setting -// c.verifiedChains and c.peerCertificates or sending the appropriate alert. -func (c *Conn) verifyServerCertificate(certificates [][]byte) error { - activeHandles := make([]*activeCert, len(certificates)) - certs := make([]*x509.Certificate, len(certificates)) - for i, asn1Data := range certificates { - cert, err := clientCertCache.newCert(asn1Data) - if err != nil { - c.sendAlert(alertBadCertificate) - return errors.New("tls: failed to parse certificate from server: " + err.Error()) - } - activeHandles[i] = cert - certs[i] = cert.cert - } - - if !c.config.InsecureSkipVerify { - opts := x509.VerifyOptions{ - Roots: c.config.RootCAs, - CurrentTime: c.config.time(), - DNSName: c.config.ServerName, - Intermediates: x509.NewCertPool(), - } - - for _, cert := range certs[1:] { - opts.Intermediates.AddCert(cert) - } - var err error - c.verifiedChains, err = certs[0].Verify(opts) - if err != nil { - c.sendAlert(alertBadCertificate) - return &CertificateVerificationError{UnverifiedCertificates: certs, Err: err} - } - } - - switch certs[0].PublicKey.(type) { - case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey: - break - default: - c.sendAlert(alertUnsupportedCertificate) - return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey) - } - - c.activeCertHandles = activeHandles - c.peerCertificates = certs - - if c.config.VerifyPeerCertificate != nil { - if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - return nil -} - -// certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS -// <= 1.2 CertificateRequest, making an effort to fill in missing information. -func certificateRequestInfoFromMsg(ctx context.Context, vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { - cri := &CertificateRequestInfo{ - AcceptableCAs: certReq.certificateAuthorities, - Version: vers, - ctx: ctx, - } - - var rsaAvail, ecAvail bool - for _, certType := range certReq.certificateTypes { - switch certType { - case certTypeRSASign: - rsaAvail = true - case certTypeECDSASign: - ecAvail = true - } - } - - if !certReq.hasSignatureAlgorithm { - // Prior to TLS 1.2, signature schemes did not exist. In this case we - // make up a list based on the acceptable certificate types, to help - // GetClientCertificate and SupportsCertificate select the right certificate. - // The hash part of the SignatureScheme is a lie here, because - // TLS 1.0 and 1.1 always use MD5+SHA1 for RSA and SHA1 for ECDSA. - switch { - case rsaAvail && ecAvail: - cri.SignatureSchemes = []SignatureScheme{ - ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, - PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1, - } - case rsaAvail: - cri.SignatureSchemes = []SignatureScheme{ - PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1, - } - case ecAvail: - cri.SignatureSchemes = []SignatureScheme{ - ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, - } - } - return cri - } - - // Filter the signature schemes based on the certificate types. - // See RFC 5246, Section 7.4.4 (where it calls this "somewhat complicated"). - cri.SignatureSchemes = make([]SignatureScheme, 0, len(certReq.supportedSignatureAlgorithms)) - for _, sigScheme := range certReq.supportedSignatureAlgorithms { - sigType, _, err := typeAndHashFromSignatureScheme(sigScheme) - if err != nil { - continue - } - switch sigType { - case signatureECDSA, signatureEd25519: - if ecAvail { - cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme) - } - case signatureRSAPSS, signaturePKCS1v15: - if rsaAvail { - cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme) - } - } - } - - return cri -} - -func (c *Conn) getClientCertificate(cri *CertificateRequestInfo) (*Certificate, error) { - if c.config.GetClientCertificate != nil { - return c.config.GetClientCertificate(cri) - } - - for _, chain := range c.config.Certificates { - if err := cri.SupportsCertificate(&chain); err != nil { - continue - } - return &chain, nil - } - - // No acceptable certificate found. Don't send a certificate. - return new(Certificate), nil -} - -// clientSessionCacheKey returns a key used to cache sessionTickets that could -// be used to resume previously negotiated TLS sessions with a server. -func clientSessionCacheKey(serverAddr net.Addr, config *Config) string { - if len(config.ServerName) > 0 { - return config.ServerName - } - return serverAddr.String() -} - -// hostnameInSNI converts name into an appropriate hostname for SNI. -// Literal IP addresses and absolute FQDNs are not permitted as SNI values. -// See RFC 6066, Section 3. -func hostnameInSNI(name string) string { - host := name - if len(host) > 0 && host[0] == '[' && host[len(host)-1] == ']' { - host = host[1 : len(host)-1] - } - if i := strings.LastIndex(host, "%"); i > 0 { - host = host[:i] - } - if net.ParseIP(host) != nil { - return "" - } - for len(name) > 0 && name[len(name)-1] == '.' { - name = name[:len(name)-1] - } - return name -} diff --git a/pkg/tls/handshake_client_tls13.go b/pkg/tls/handshake_client_tls13.go deleted file mode 100644 index 3e83198b6..000000000 --- a/pkg/tls/handshake_client_tls13.go +++ /dev/null @@ -1,713 +0,0 @@ -// Copyright 2018 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "bytes" - "context" - "crypto" - "crypto/ecdh" - "crypto/hmac" - "crypto/rsa" - "errors" - "hash" - "time" -) - -type clientHandshakeStateTLS13 struct { - c *Conn - ctx context.Context - serverHello *serverHelloMsg - hello *clientHelloMsg - ecdheKey *ecdh.PrivateKey - - session *ClientSessionState - earlySecret []byte - binderKey []byte - - certReq *certificateRequestMsgTLS13 - usingPSK bool - sentDummyCCS bool - suite *cipherSuiteTLS13 - transcript hash.Hash - masterSecret []byte - trafficSecret []byte // client_application_traffic_secret_0 - cacheKey string -} - -// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheKey, and, -// optionally, hs.session, hs.earlySecret and hs.binderKey to be set. -func (hs *clientHandshakeStateTLS13) handshake() error { - c := hs.c - - if needFIPS() { - return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode") - } - - // The server must not select TLS 1.3 in a renegotiation. See RFC 8446, - // sections 4.1.2 and 4.1.3. - if c.handshakes > 255 { - return errors.New("tls: server selected TLS 1.3 in a renegotiation") - } - - // Consistency check on the presence of a keyShare and its parameters. - if hs.ecdheKey == nil || len(hs.hello.keyShares) != 1 { - return c.sendAlert(alertInternalError) - } - - if err := hs.checkServerHelloOrHRR(); err != nil { - return err - } - - hs.transcript = hs.suite.hash.New() - - if err := transcriptMsg(hs.hello, hs.transcript); err != nil { - return err - } - - if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { - if err := hs.sendDummyChangeCipherSpec(); err != nil { - return err - } - if err := hs.processHelloRetryRequest(); err != nil { - return err - } - } - - if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil { - return err - } - - c.buffering = true - if err := hs.processServerHello(); err != nil { - return err - } - if err := hs.sendDummyChangeCipherSpec(); err != nil { - return err - } - if err := hs.establishHandshakeKeys(); err != nil { - return err - } - if err := hs.readServerParameters(); err != nil { - return err - } - if err := hs.readServerCertificate(); err != nil { - return err - } - if err := hs.readServerFinished(); err != nil { - return err - } - if err := hs.sendClientCertificate(); err != nil { - return err - } - if err := hs.sendClientFinished(); err != nil { - return err - } - if _, err := c.flush(); err != nil { - return err - } - - c.handshakeStatus = 255 - // Enable kernel TLS if possible - if err := c.enableKernelTLS(c.cipherSuite, c.in.key, c.out.key, c.in.iv, c.out.iv); err != nil { - return err - } - - return nil -} - -// checkServerHelloOrHRR does validity checks that apply to both ServerHello and -// HelloRetryRequest messages. It sets hs.suite. -func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error { - c := hs.c - - if hs.serverHello.supportedVersion == 0 { - c.sendAlert(alertMissingExtension) - return errors.New("tls: server selected TLS 1.3 using the legacy version field") - } - - if hs.serverHello.supportedVersion != VersionTLS13 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected an invalid version after a HelloRetryRequest") - } - - if hs.serverHello.vers != VersionTLS12 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server sent an incorrect legacy version") - } - - if hs.serverHello.ocspStapling || - hs.serverHello.ticketSupported || - hs.serverHello.secureRenegotiationSupported || - len(hs.serverHello.secureRenegotiation) != 0 || - len(hs.serverHello.alpnProtocol) != 0 || - len(hs.serverHello.scts) != 0 { - c.sendAlert(alertUnsupportedExtension) - return errors.New("tls: server sent a ServerHello extension forbidden in TLS 1.3") - } - - if !bytes.Equal(hs.hello.sessionId, hs.serverHello.sessionId) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server did not echo the legacy session ID") - } - - if hs.serverHello.compressionMethod != compressionNone { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected unsupported compression format") - } - - selectedSuite := mutualCipherSuiteTLS13(hs.hello.cipherSuites, hs.serverHello.cipherSuite) - if hs.suite != nil && selectedSuite != hs.suite { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server changed cipher suite after a HelloRetryRequest") - } - if selectedSuite == nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server chose an unconfigured cipher suite") - } - hs.suite = selectedSuite - c.cipherSuite = hs.suite.id - - return nil -} - -// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility -// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. -func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error { - if hs.sentDummyCCS { - return nil - } - hs.sentDummyCCS = true - - return hs.c.writeChangeCipherRecord() -} - -// processHelloRetryRequest handles the HRR in hs.serverHello, modifies and -// resends hs.hello, and reads the new ServerHello into hs.serverHello. -func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { - c := hs.c - - // The first ClientHello gets double-hashed into the transcript upon a - // HelloRetryRequest. (The idea is that the server might offload transcript - // storage to the client in the cookie.) See RFC 8446, Section 4.4.1. - chHash := hs.transcript.Sum(nil) - hs.transcript.Reset() - hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) - hs.transcript.Write(chHash) - if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil { - return err - } - - // The only HelloRetryRequest extensions we support are key_share and - // cookie, and clients must abort the handshake if the HRR would not result - // in any change in the ClientHello. - if hs.serverHello.selectedGroup == 0 && hs.serverHello.cookie == nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server sent an unnecessary HelloRetryRequest message") - } - - if hs.serverHello.cookie != nil { - hs.hello.cookie = hs.serverHello.cookie - } - - if hs.serverHello.serverShare.group != 0 { - c.sendAlert(alertDecodeError) - return errors.New("tls: received malformed key_share extension") - } - - // If the server sent a key_share extension selecting a group, ensure it's - // a group we advertised but did not send a key share for, and send a key - // share for it this time. - if curveID := hs.serverHello.selectedGroup; curveID != 0 { - curveOK := false - for _, id := range hs.hello.supportedCurves { - if id == curveID { - curveOK = true - break - } - } - if !curveOK { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected unsupported group") - } - if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); sentID == curveID { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share") - } - if _, ok := curveForCurveID(curveID); !ok { - c.sendAlert(alertInternalError) - return errors.New("tls: CurvePreferences includes unsupported curve") - } - key, err := generateECDHEKey(c.config.rand(), curveID) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - hs.ecdheKey = key - hs.hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} - } - - hs.hello.raw = nil - if len(hs.hello.pskIdentities) > 0 { - pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) - if pskSuite == nil { - return c.sendAlert(alertInternalError) - } - if pskSuite.hash == hs.suite.hash { - // Update binders and obfuscated_ticket_age. - ticketAge := uint32(c.config.time().Sub(hs.session.receivedAt) / time.Millisecond) - hs.hello.pskIdentities[0].obfuscatedTicketAge = ticketAge + hs.session.ageAdd - - transcript := hs.suite.hash.New() - transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) - transcript.Write(chHash) - if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil { - return err - } - helloBytes, err := hs.hello.marshalWithoutBinders() - if err != nil { - return err - } - transcript.Write(helloBytes) - pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)} - if err := hs.hello.updateBinders(pskBinders); err != nil { - return err - } - } else { - // Server selected a cipher suite incompatible with the PSK. - hs.hello.pskIdentities = nil - hs.hello.pskBinders = nil - } - } - - if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil { - return err - } - - // serverHelloMsg is not included in the transcript - msg, err := c.readHandshake(nil) - if err != nil { - return err - } - - serverHello, ok := msg.(*serverHelloMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(serverHello, msg) - } - hs.serverHello = serverHello - - if err := hs.checkServerHelloOrHRR(); err != nil { - return err - } - - return nil -} - -func (hs *clientHandshakeStateTLS13) processServerHello() error { - c := hs.c - - if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { - c.sendAlert(alertUnexpectedMessage) - return errors.New("tls: server sent two HelloRetryRequest messages") - } - - if len(hs.serverHello.cookie) != 0 { - c.sendAlert(alertUnsupportedExtension) - return errors.New("tls: server sent a cookie in a normal ServerHello") - } - - if hs.serverHello.selectedGroup != 0 { - c.sendAlert(alertDecodeError) - return errors.New("tls: malformed key_share extension") - } - - if hs.serverHello.serverShare.group == 0 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server did not send a key share") - } - if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); hs.serverHello.serverShare.group != sentID { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected unsupported group") - } - - if !hs.serverHello.selectedIdentityPresent { - return nil - } - - if int(hs.serverHello.selectedIdentity) >= len(hs.hello.pskIdentities) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected an invalid PSK") - } - - if len(hs.hello.pskIdentities) != 1 || hs.session == nil { - return c.sendAlert(alertInternalError) - } - pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) - if pskSuite == nil { - return c.sendAlert(alertInternalError) - } - if pskSuite.hash != hs.suite.hash { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected an invalid PSK and cipher suite pair") - } - - hs.usingPSK = true - c.didResume = true - c.peerCertificates = hs.session.serverCertificates - c.verifiedChains = hs.session.verifiedChains - c.ocspResponse = hs.session.ocspResponse - c.scts = hs.session.scts - return nil -} - -func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { - c := hs.c - - peerKey, err := hs.ecdheKey.Curve().NewPublicKey(hs.serverHello.serverShare.data) - if err != nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid server key share") - } - sharedKey, err := hs.ecdheKey.ECDH(peerKey) - if err != nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid server key share") - } - - earlySecret := hs.earlySecret - if !hs.usingPSK { - earlySecret = hs.suite.extract(nil, nil) - } - - handshakeSecret := hs.suite.extract(sharedKey, - hs.suite.deriveSecret(earlySecret, "derived", nil)) - - clientSecret := hs.suite.deriveSecret(handshakeSecret, - clientHandshakeTrafficLabel, hs.transcript) - c.out.setTrafficSecret(hs.suite, clientSecret) - serverSecret := hs.suite.deriveSecret(handshakeSecret, - serverHandshakeTrafficLabel, hs.transcript) - c.in.setTrafficSecret(hs.suite, serverSecret) - - err = c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.hello.random, serverSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - hs.masterSecret = hs.suite.extract(nil, - hs.suite.deriveSecret(handshakeSecret, "derived", nil)) - - return nil -} - -func (hs *clientHandshakeStateTLS13) readServerParameters() error { - c := hs.c - - msg, err := c.readHandshake(hs.transcript) - if err != nil { - return err - } - - encryptedExtensions, ok := msg.(*encryptedExtensionsMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(encryptedExtensions, msg) - } - - if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil { - c.sendAlert(alertUnsupportedExtension) - return err - } - c.clientProtocol = encryptedExtensions.alpnProtocol - - return nil -} - -func (hs *clientHandshakeStateTLS13) readServerCertificate() error { - c := hs.c - - // Either a PSK or a certificate is always used, but not both. - // See RFC 8446, Section 4.1.1. - if hs.usingPSK { - // Make sure the connection is still being verified whether or not this - // is a resumption. Resumptions currently don't reverify certificates so - // they don't call verifyServerCertificate. See Issue 31641. - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - return nil - } - - msg, err := c.readHandshake(hs.transcript) - if err != nil { - return err - } - - certReq, ok := msg.(*certificateRequestMsgTLS13) - if ok { - hs.certReq = certReq - - msg, err = c.readHandshake(hs.transcript) - if err != nil { - return err - } - } - - certMsg, ok := msg.(*certificateMsgTLS13) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certMsg, msg) - } - if len(certMsg.certificate.Certificate) == 0 { - c.sendAlert(alertDecodeError) - return errors.New("tls: received empty certificates message") - } - - c.scts = certMsg.certificate.SignedCertificateTimestamps - c.ocspResponse = certMsg.certificate.OCSPStaple - - if err := c.verifyServerCertificate(certMsg.certificate.Certificate); err != nil { - return err - } - - // certificateVerifyMsg is included in the transcript, but not until - // after we verify the handshake signature, since the state before - // this message was sent is used. - msg, err = c.readHandshake(nil) - if err != nil { - return err - } - - certVerify, ok := msg.(*certificateVerifyMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certVerify, msg) - } - - // See RFC 8446, Section 4.4.3. - if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: certificate used with invalid signature algorithm") - } - sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) - if err != nil { - return c.sendAlert(alertInternalError) - } - if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: certificate used with invalid signature algorithm") - } - signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) - if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, - sigHash, signed, certVerify.signature); err != nil { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid signature by the server certificate: " + err.Error()) - } - - if err := transcriptMsg(certVerify, hs.transcript); err != nil { - return err - } - - return nil -} - -func (hs *clientHandshakeStateTLS13) readServerFinished() error { - c := hs.c - - // finishedMsg is included in the transcript, but not until after we - // check the client version, since the state before this message was - // sent is used during verification. - msg, err := c.readHandshake(nil) - if err != nil { - return err - } - - finished, ok := msg.(*finishedMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(finished, msg) - } - - expectedMAC := hs.suite.finishedHash(c.in.trafficSecret, hs.transcript) - if !hmac.Equal(expectedMAC, finished.verifyData) { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid server finished hash") - } - - if err := transcriptMsg(finished, hs.transcript); err != nil { - return err - } - - // Derive secrets that take context through the server Finished. - - hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, - clientApplicationTrafficLabel, hs.transcript) - serverSecret := hs.suite.deriveSecret(hs.masterSecret, - serverApplicationTrafficLabel, hs.transcript) - c.in.setTrafficSecret(hs.suite, serverSecret) - - err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.hello.random, serverSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript) - - return nil -} - -func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { - c := hs.c - - if hs.certReq == nil { - return nil - } - - cert, err := c.getClientCertificate(&CertificateRequestInfo{ - AcceptableCAs: hs.certReq.certificateAuthorities, - SignatureSchemes: hs.certReq.supportedSignatureAlgorithms, - Version: c.vers, - ctx: hs.ctx, - }) - if err != nil { - return err - } - - certMsg := new(certificateMsgTLS13) - - certMsg.certificate = *cert - certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0 - certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0 - - if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil { - return err - } - - // If we sent an empty certificate message, skip the CertificateVerify. - if len(cert.Certificate) == 0 { - return nil - } - - certVerifyMsg := new(certificateVerifyMsg) - certVerifyMsg.hasSignatureAlgorithm = true - - certVerifyMsg.signatureAlgorithm, err = selectSignatureScheme(c.vers, cert, hs.certReq.supportedSignatureAlgorithms) - if err != nil { - // getClientCertificate returned a certificate incompatible with the - // CertificateRequestInfo supported signature algorithms. - c.sendAlert(alertHandshakeFailure) - return err - } - - sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerifyMsg.signatureAlgorithm) - if err != nil { - return c.sendAlert(alertInternalError) - } - - signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) - signOpts := crypto.SignerOpts(sigHash) - if sigType == signatureRSAPSS { - signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} - } - sig, err := cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) - if err != nil { - c.sendAlert(alertInternalError) - return errors.New("tls: failed to sign handshake: " + err.Error()) - } - certVerifyMsg.signature = sig - - if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil { - return err - } - - return nil -} - -func (hs *clientHandshakeStateTLS13) sendClientFinished() error { - c := hs.c - - finished := &finishedMsg{ - verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), - } - - if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil { - return err - } - - c.out.setTrafficSecret(hs.suite, hs.trafficSecret) - - if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil { - c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret, - resumptionLabel, hs.transcript) - } - - return nil -} - -func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error { - if !c.isClient { - c.sendAlert(alertUnexpectedMessage) - return errors.New("tls: received new session ticket from a client") - } - - if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { - return nil - } - - // See RFC 8446, Section 4.6.1. - if msg.lifetime == 0 { - return nil - } - lifetime := time.Duration(msg.lifetime) * time.Second - if lifetime > maxSessionTicketLifetime { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: received a session ticket with invalid lifetime") - } - - cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) - if cipherSuite == nil || c.resumptionSecret == nil { - return c.sendAlert(alertInternalError) - } - - // Save the resumption_master_secret and nonce instead of deriving the PSK - // to do the least amount of work on NewSessionTicket messages before we - // know if the ticket will be used. Forward secrecy of resumed connections - // is guaranteed by the requirement for pskModeDHE. - session := &ClientSessionState{ - sessionTicket: msg.label, - vers: c.vers, - cipherSuite: c.cipherSuite, - masterSecret: c.resumptionSecret, - serverCertificates: c.peerCertificates, - verifiedChains: c.verifiedChains, - receivedAt: c.config.time(), - nonce: msg.nonce, - useBy: c.config.time().Add(lifetime), - ageAdd: msg.ageAdd, - ocspResponse: c.ocspResponse, - scts: c.scts, - } - - cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config) - c.config.ClientSessionCache.Put(cacheKey, session) - - return nil -} diff --git a/pkg/tls/handshake_messages.go b/pkg/tls/handshake_messages.go deleted file mode 100644 index 695aacf12..000000000 --- a/pkg/tls/handshake_messages.go +++ /dev/null @@ -1,1852 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "errors" - "fmt" - "strings" - - "golang.org/x/crypto/cryptobyte" -) - -// The marshalingFunction type is an adapter to allow the use of ordinary -// functions as cryptobyte.MarshalingValue. -type marshalingFunction func(b *cryptobyte.Builder) error - -func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error { - return f(b) -} - -// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If -// the length of the sequence is not the value specified, it produces an error. -func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) { - b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error { - if len(v) != n { - return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v)) - } - b.AddBytes(v) - return nil - })) -} - -// addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder. -func addUint64(b *cryptobyte.Builder, v uint64) { - b.AddUint32(uint32(v >> 32)) - b.AddUint32(uint32(v)) -} - -// readUint64 decodes a big-endian, 64-bit value into out and advances over it. -// It reports whether the read was successful. -func readUint64(s *cryptobyte.String, out *uint64) bool { - var hi, lo uint32 - if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) { - return false - } - *out = uint64(hi)<<32 | uint64(lo) - return true -} - -// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a -// []byte instead of a cryptobyte.String. -func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { - return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out)) -} - -// readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a -// []byte instead of a cryptobyte.String. -func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { - return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out)) -} - -// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a -// []byte instead of a cryptobyte.String. -func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { - return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out)) -} - -type clientHelloMsg struct { - raw []byte - vers uint16 - random []byte - sessionId []byte - cipherSuites []uint16 - compressionMethods []uint8 - serverName string - ocspStapling bool - supportedCurves []CurveID - supportedPoints []uint8 - ticketSupported bool - sessionTicket []uint8 - supportedSignatureAlgorithms []SignatureScheme - supportedSignatureAlgorithmsCert []SignatureScheme - secureRenegotiationSupported bool - secureRenegotiation []byte - alpnProtocols []string - scts bool - supportedVersions []uint16 - cookie []byte - keyShares []keyShare - earlyData bool - pskModes []uint8 - pskIdentities []pskIdentity - pskBinders [][]byte -} - -func (m *clientHelloMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var exts cryptobyte.Builder - if len(m.serverName) > 0 { - // RFC 6066, Section 3 - exts.AddUint16(extensionServerName) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint8(0) // name_type = host_name - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes([]byte(m.serverName)) - }) - }) - }) - } - if m.ocspStapling { - // RFC 4366, Section 3.6 - exts.AddUint16(extensionStatusRequest) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint8(1) // status_type = ocsp - exts.AddUint16(0) // empty responder_id_list - exts.AddUint16(0) // empty request_extensions - }) - } - if len(m.supportedCurves) > 0 { - // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 - exts.AddUint16(extensionSupportedCurves) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - for _, curve := range m.supportedCurves { - exts.AddUint16(uint16(curve)) - } - }) - }) - } - if len(m.supportedPoints) > 0 { - // RFC 4492, Section 5.1.2 - exts.AddUint16(extensionSupportedPoints) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(m.supportedPoints) - }) - }) - } - if m.ticketSupported { - // RFC 5077, Section 3.2 - exts.AddUint16(extensionSessionTicket) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(m.sessionTicket) - }) - } - if len(m.supportedSignatureAlgorithms) > 0 { - // RFC 5246, Section 7.4.1.4.1 - exts.AddUint16(extensionSignatureAlgorithms) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithms { - exts.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if len(m.supportedSignatureAlgorithmsCert) > 0 { - // RFC 8446, Section 4.2.3 - exts.AddUint16(extensionSignatureAlgorithmsCert) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { - exts.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if m.secureRenegotiationSupported { - // RFC 5746, Section 3.2 - exts.AddUint16(extensionRenegotiationInfo) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(m.secureRenegotiation) - }) - }) - } - if len(m.alpnProtocols) > 0 { - // RFC 7301, Section 3.1 - exts.AddUint16(extensionALPN) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - for _, proto := range m.alpnProtocols { - exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes([]byte(proto)) - }) - } - }) - }) - } - if m.scts { - // RFC 6962, Section 3.3.1 - exts.AddUint16(extensionSCT) - exts.AddUint16(0) // empty extension_data - } - if len(m.supportedVersions) > 0 { - // RFC 8446, Section 4.2.1 - exts.AddUint16(extensionSupportedVersions) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { - for _, vers := range m.supportedVersions { - exts.AddUint16(vers) - } - }) - }) - } - if len(m.cookie) > 0 { - // RFC 8446, Section 4.2.2 - exts.AddUint16(extensionCookie) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(m.cookie) - }) - }) - } - if len(m.keyShares) > 0 { - // RFC 8446, Section 4.2.8 - exts.AddUint16(extensionKeyShare) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - for _, ks := range m.keyShares { - exts.AddUint16(uint16(ks.group)) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(ks.data) - }) - } - }) - }) - } - if m.earlyData { - // RFC 8446, Section 4.2.10 - exts.AddUint16(extensionEarlyData) - exts.AddUint16(0) // empty extension_data - } - if len(m.pskModes) > 0 { - // RFC 8446, Section 4.2.9 - exts.AddUint16(extensionPSKModes) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(m.pskModes) - }) - }) - } - if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension - // RFC 8446, Section 4.2.11 - exts.AddUint16(extensionPreSharedKey) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - for _, psk := range m.pskIdentities { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(psk.label) - }) - exts.AddUint32(psk.obfuscatedTicketAge) - } - }) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - for _, binder := range m.pskBinders { - exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(binder) - }) - } - }) - }) - } - extBytes, err := exts.Bytes() - if err != nil { - return nil, err - } - - var b cryptobyte.Builder - b.AddUint8(typeClientHello) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(m.vers) - addBytesWithLength(b, m.random, 32) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.sessionId) - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, suite := range m.cipherSuites { - b.AddUint16(suite) - } - }) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.compressionMethods) - }) - - if len(extBytes) > 0 { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(extBytes) - }) - } - }) - - m.raw, err = b.Bytes() - return m.raw, err -} - -// marshalWithoutBinders returns the ClientHello through the -// PreSharedKeyExtension.identities field, according to RFC 8446, Section -// 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length. -func (m *clientHelloMsg) marshalWithoutBinders() ([]byte, error) { - bindersLen := 2 // uint16 length prefix - for _, binder := range m.pskBinders { - bindersLen += 1 // uint8 length prefix - bindersLen += len(binder) - } - - fullMessage, err := m.marshal() - if err != nil { - return nil, err - } - return fullMessage[:len(fullMessage)-bindersLen], nil -} - -// updateBinders updates the m.pskBinders field, if necessary updating the -// cached marshaled representation. The supplied binders must have the same -// length as the current m.pskBinders. -func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error { - if len(pskBinders) != len(m.pskBinders) { - return errors.New("tls: internal error: pskBinders length mismatch") - } - for i := range m.pskBinders { - if len(pskBinders[i]) != len(m.pskBinders[i]) { - return errors.New("tls: internal error: pskBinders length mismatch") - } - } - m.pskBinders = pskBinders - if m.raw != nil { - helloBytes, err := m.marshalWithoutBinders() - if err != nil { - return err - } - lenWithoutBinders := len(helloBytes) - b := cryptobyte.NewFixedBuilder(m.raw[:lenWithoutBinders]) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, binder := range m.pskBinders { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(binder) - }) - } - }) - if out, err := b.Bytes(); err != nil || len(out) != len(m.raw) { - return errors.New("tls: internal error: failed to update binders") - } - } - - return nil -} - -func (m *clientHelloMsg) unmarshal(data []byte) bool { - *m = clientHelloMsg{raw: data} - s := cryptobyte.String(data) - - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || - !readUint8LengthPrefixed(&s, &m.sessionId) { - return false - } - - var cipherSuites cryptobyte.String - if !s.ReadUint16LengthPrefixed(&cipherSuites) { - return false - } - m.cipherSuites = []uint16{} - m.secureRenegotiationSupported = false - for !cipherSuites.Empty() { - var suite uint16 - if !cipherSuites.ReadUint16(&suite) { - return false - } - if suite == scsvRenegotiation { - m.secureRenegotiationSupported = true - } - m.cipherSuites = append(m.cipherSuites, suite) - } - - if !readUint8LengthPrefixed(&s, &m.compressionMethods) { - return false - } - - if s.Empty() { - // ClientHello is optionally followed by extension data - return true - } - - var extensions cryptobyte.String - if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { - return false - } - - seenExts := make(map[uint16]bool) - for !extensions.Empty() { - var extension uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&extension) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - - if seenExts[extension] { - return false - } - seenExts[extension] = true - - switch extension { - case extensionServerName: - // RFC 6066, Section 3 - var nameList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() { - return false - } - for !nameList.Empty() { - var nameType uint8 - var serverName cryptobyte.String - if !nameList.ReadUint8(&nameType) || - !nameList.ReadUint16LengthPrefixed(&serverName) || - serverName.Empty() { - return false - } - if nameType != 0 { - continue - } - if len(m.serverName) != 0 { - // Multiple names of the same name_type are prohibited. - return false - } - m.serverName = string(serverName) - // An SNI value may not include a trailing dot. - if strings.HasSuffix(m.serverName, ".") { - return false - } - } - case extensionStatusRequest: - // RFC 4366, Section 3.6 - var statusType uint8 - var ignored cryptobyte.String - if !extData.ReadUint8(&statusType) || - !extData.ReadUint16LengthPrefixed(&ignored) || - !extData.ReadUint16LengthPrefixed(&ignored) { - return false - } - m.ocspStapling = statusType == statusTypeOCSP - case extensionSupportedCurves: - // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 - var curves cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() { - return false - } - for !curves.Empty() { - var curve uint16 - if !curves.ReadUint16(&curve) { - return false - } - m.supportedCurves = append(m.supportedCurves, CurveID(curve)) - } - case extensionSupportedPoints: - // RFC 4492, Section 5.1.2 - if !readUint8LengthPrefixed(&extData, &m.supportedPoints) || - len(m.supportedPoints) == 0 { - return false - } - case extensionSessionTicket: - // RFC 5077, Section 3.2 - m.ticketSupported = true - extData.ReadBytes(&m.sessionTicket, len(extData)) - case extensionSignatureAlgorithms: - // RFC 5246, Section 7.4.1.4.1 - var sigAndAlgs cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { - return false - } - for !sigAndAlgs.Empty() { - var sigAndAlg uint16 - if !sigAndAlgs.ReadUint16(&sigAndAlg) { - return false - } - m.supportedSignatureAlgorithms = append( - m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) - } - case extensionSignatureAlgorithmsCert: - // RFC 8446, Section 4.2.3 - var sigAndAlgs cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { - return false - } - for !sigAndAlgs.Empty() { - var sigAndAlg uint16 - if !sigAndAlgs.ReadUint16(&sigAndAlg) { - return false - } - m.supportedSignatureAlgorithmsCert = append( - m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) - } - case extensionRenegotiationInfo: - // RFC 5746, Section 3.2 - if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { - return false - } - m.secureRenegotiationSupported = true - case extensionALPN: - // RFC 7301, Section 3.1 - var protoList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { - return false - } - for !protoList.Empty() { - var proto cryptobyte.String - if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() { - return false - } - m.alpnProtocols = append(m.alpnProtocols, string(proto)) - } - case extensionSCT: - // RFC 6962, Section 3.3.1 - m.scts = true - case extensionSupportedVersions: - // RFC 8446, Section 4.2.1 - var versList cryptobyte.String - if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() { - return false - } - for !versList.Empty() { - var vers uint16 - if !versList.ReadUint16(&vers) { - return false - } - m.supportedVersions = append(m.supportedVersions, vers) - } - case extensionCookie: - // RFC 8446, Section 4.2.2 - if !readUint16LengthPrefixed(&extData, &m.cookie) || - len(m.cookie) == 0 { - return false - } - case extensionKeyShare: - // RFC 8446, Section 4.2.8 - var clientShares cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&clientShares) { - return false - } - for !clientShares.Empty() { - var ks keyShare - if !clientShares.ReadUint16((*uint16)(&ks.group)) || - !readUint16LengthPrefixed(&clientShares, &ks.data) || - len(ks.data) == 0 { - return false - } - m.keyShares = append(m.keyShares, ks) - } - case extensionEarlyData: - // RFC 8446, Section 4.2.10 - m.earlyData = true - case extensionPSKModes: - // RFC 8446, Section 4.2.9 - if !readUint8LengthPrefixed(&extData, &m.pskModes) { - return false - } - case extensionPreSharedKey: - // RFC 8446, Section 4.2.11 - if !extensions.Empty() { - return false // pre_shared_key must be the last extension - } - var identities cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() { - return false - } - for !identities.Empty() { - var psk pskIdentity - if !readUint16LengthPrefixed(&identities, &psk.label) || - !identities.ReadUint32(&psk.obfuscatedTicketAge) || - len(psk.label) == 0 { - return false - } - m.pskIdentities = append(m.pskIdentities, psk) - } - var binders cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() { - return false - } - for !binders.Empty() { - var binder []byte - if !readUint8LengthPrefixed(&binders, &binder) || - len(binder) == 0 { - return false - } - m.pskBinders = append(m.pskBinders, binder) - } - default: - // Ignore unknown extensions. - continue - } - - if !extData.Empty() { - return false - } - } - - return true -} - -type serverHelloMsg struct { - raw []byte - vers uint16 - random []byte - sessionId []byte - cipherSuite uint16 - compressionMethod uint8 - ocspStapling bool - ticketSupported bool - secureRenegotiationSupported bool - secureRenegotiation []byte - alpnProtocol string - scts [][]byte - supportedVersion uint16 - serverShare keyShare - selectedIdentityPresent bool - selectedIdentity uint16 - supportedPoints []uint8 - - // HelloRetryRequest extensions - cookie []byte - selectedGroup CurveID -} - -func (m *serverHelloMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var exts cryptobyte.Builder - if m.ocspStapling { - exts.AddUint16(extensionStatusRequest) - exts.AddUint16(0) // empty extension_data - } - if m.ticketSupported { - exts.AddUint16(extensionSessionTicket) - exts.AddUint16(0) // empty extension_data - } - if m.secureRenegotiationSupported { - exts.AddUint16(extensionRenegotiationInfo) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(m.secureRenegotiation) - }) - }) - } - if len(m.alpnProtocol) > 0 { - exts.AddUint16(extensionALPN) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes([]byte(m.alpnProtocol)) - }) - }) - }) - } - if len(m.scts) > 0 { - exts.AddUint16(extensionSCT) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - for _, sct := range m.scts { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(sct) - }) - } - }) - }) - } - if m.supportedVersion != 0 { - exts.AddUint16(extensionSupportedVersions) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16(m.supportedVersion) - }) - } - if m.serverShare.group != 0 { - exts.AddUint16(extensionKeyShare) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16(uint16(m.serverShare.group)) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(m.serverShare.data) - }) - }) - } - if m.selectedIdentityPresent { - exts.AddUint16(extensionPreSharedKey) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16(m.selectedIdentity) - }) - } - - if len(m.cookie) > 0 { - exts.AddUint16(extensionCookie) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(m.cookie) - }) - }) - } - if m.selectedGroup != 0 { - exts.AddUint16(extensionKeyShare) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16(uint16(m.selectedGroup)) - }) - } - if len(m.supportedPoints) > 0 { - exts.AddUint16(extensionSupportedPoints) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(m.supportedPoints) - }) - }) - } - - extBytes, err := exts.Bytes() - if err != nil { - return nil, err - } - - var b cryptobyte.Builder - b.AddUint8(typeServerHello) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(m.vers) - addBytesWithLength(b, m.random, 32) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.sessionId) - }) - b.AddUint16(m.cipherSuite) - b.AddUint8(m.compressionMethod) - - if len(extBytes) > 0 { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(extBytes) - }) - } - }) - - m.raw, err = b.Bytes() - return m.raw, err -} - -func (m *serverHelloMsg) unmarshal(data []byte) bool { - *m = serverHelloMsg{raw: data} - s := cryptobyte.String(data) - - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || - !readUint8LengthPrefixed(&s, &m.sessionId) || - !s.ReadUint16(&m.cipherSuite) || - !s.ReadUint8(&m.compressionMethod) { - return false - } - - if s.Empty() { - // ServerHello is optionally followed by extension data - return true - } - - var extensions cryptobyte.String - if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { - return false - } - - seenExts := make(map[uint16]bool) - for !extensions.Empty() { - var extension uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&extension) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - - if seenExts[extension] { - return false - } - seenExts[extension] = true - - switch extension { - case extensionStatusRequest: - m.ocspStapling = true - case extensionSessionTicket: - m.ticketSupported = true - case extensionRenegotiationInfo: - if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { - return false - } - m.secureRenegotiationSupported = true - case extensionALPN: - var protoList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { - return false - } - var proto cryptobyte.String - if !protoList.ReadUint8LengthPrefixed(&proto) || - proto.Empty() || !protoList.Empty() { - return false - } - m.alpnProtocol = string(proto) - case extensionSCT: - var sctList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { - return false - } - for !sctList.Empty() { - var sct []byte - if !readUint16LengthPrefixed(&sctList, &sct) || - len(sct) == 0 { - return false - } - m.scts = append(m.scts, sct) - } - case extensionSupportedVersions: - if !extData.ReadUint16(&m.supportedVersion) { - return false - } - case extensionCookie: - if !readUint16LengthPrefixed(&extData, &m.cookie) || - len(m.cookie) == 0 { - return false - } - case extensionKeyShare: - // This extension has different formats in SH and HRR, accept either - // and let the handshake logic decide. See RFC 8446, Section 4.2.8. - if len(extData) == 2 { - if !extData.ReadUint16((*uint16)(&m.selectedGroup)) { - return false - } - } else { - if !extData.ReadUint16((*uint16)(&m.serverShare.group)) || - !readUint16LengthPrefixed(&extData, &m.serverShare.data) { - return false - } - } - case extensionPreSharedKey: - m.selectedIdentityPresent = true - if !extData.ReadUint16(&m.selectedIdentity) { - return false - } - case extensionSupportedPoints: - // RFC 4492, Section 5.1.2 - if !readUint8LengthPrefixed(&extData, &m.supportedPoints) || - len(m.supportedPoints) == 0 { - return false - } - default: - // Ignore unknown extensions. - continue - } - - if !extData.Empty() { - return false - } - } - - return true -} - -type encryptedExtensionsMsg struct { - raw []byte - alpnProtocol string -} - -func (m *encryptedExtensionsMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var b cryptobyte.Builder - b.AddUint8(typeEncryptedExtensions) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if len(m.alpnProtocol) > 0 { - b.AddUint16(extensionALPN) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(m.alpnProtocol)) - }) - }) - }) - } - }) - }) - - var err error - m.raw, err = b.Bytes() - return m.raw, err -} - -func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { - *m = encryptedExtensionsMsg{raw: data} - s := cryptobyte.String(data) - - var extensions cryptobyte.String - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { - return false - } - - for !extensions.Empty() { - var extension uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&extension) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - - switch extension { - case extensionALPN: - var protoList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { - return false - } - var proto cryptobyte.String - if !protoList.ReadUint8LengthPrefixed(&proto) || - proto.Empty() || !protoList.Empty() { - return false - } - m.alpnProtocol = string(proto) - default: - // Ignore unknown extensions. - continue - } - - if !extData.Empty() { - return false - } - } - - return true -} - -type endOfEarlyDataMsg struct{} - -func (m *endOfEarlyDataMsg) marshal() ([]byte, error) { - x := make([]byte, 4) - x[0] = typeEndOfEarlyData - return x, nil -} - -func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool { - return len(data) == 4 -} - -type keyUpdateMsg struct { - raw []byte - updateRequested bool -} - -func (m *keyUpdateMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var b cryptobyte.Builder - b.AddUint8(typeKeyUpdate) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - if m.updateRequested { - b.AddUint8(1) - } else { - b.AddUint8(0) - } - }) - - var err error - m.raw, err = b.Bytes() - return m.raw, err -} - -func (m *keyUpdateMsg) unmarshal(data []byte) bool { - m.raw = data - s := cryptobyte.String(data) - - var updateRequested uint8 - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint8(&updateRequested) || !s.Empty() { - return false - } - switch updateRequested { - case 0: - m.updateRequested = false - case 1: - m.updateRequested = true - default: - return false - } - return true -} - -type newSessionTicketMsgTLS13 struct { - raw []byte - lifetime uint32 - ageAdd uint32 - nonce []byte - label []byte - maxEarlyData uint32 -} - -func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var b cryptobyte.Builder - b.AddUint8(typeNewSessionTicket) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint32(m.lifetime) - b.AddUint32(m.ageAdd) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.nonce) - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.label) - }) - - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if m.maxEarlyData > 0 { - b.AddUint16(extensionEarlyData) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint32(m.maxEarlyData) - }) - } - }) - }) - - var err error - m.raw, err = b.Bytes() - return m.raw, err -} - -func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { - *m = newSessionTicketMsgTLS13{raw: data} - s := cryptobyte.String(data) - - var extensions cryptobyte.String - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint32(&m.lifetime) || - !s.ReadUint32(&m.ageAdd) || - !readUint8LengthPrefixed(&s, &m.nonce) || - !readUint16LengthPrefixed(&s, &m.label) || - !s.ReadUint16LengthPrefixed(&extensions) || - !s.Empty() { - return false - } - - for !extensions.Empty() { - var extension uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&extension) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - - switch extension { - case extensionEarlyData: - if !extData.ReadUint32(&m.maxEarlyData) { - return false - } - default: - // Ignore unknown extensions. - continue - } - - if !extData.Empty() { - return false - } - } - - return true -} - -type certificateRequestMsgTLS13 struct { - raw []byte - ocspStapling bool - scts bool - supportedSignatureAlgorithms []SignatureScheme - supportedSignatureAlgorithmsCert []SignatureScheme - certificateAuthorities [][]byte -} - -func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var b cryptobyte.Builder - b.AddUint8(typeCertificateRequest) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - // certificate_request_context (SHALL be zero length unless used for - // post-handshake authentication) - b.AddUint8(0) - - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if m.ocspStapling { - b.AddUint16(extensionStatusRequest) - b.AddUint16(0) // empty extension_data - } - if m.scts { - // RFC 8446, Section 4.4.2.1 makes no mention of - // signed_certificate_timestamp in CertificateRequest, but - // "Extensions in the Certificate message from the client MUST - // correspond to extensions in the CertificateRequest message - // from the server." and it appears in the table in Section 4.2. - b.AddUint16(extensionSCT) - b.AddUint16(0) // empty extension_data - } - if len(m.supportedSignatureAlgorithms) > 0 { - b.AddUint16(extensionSignatureAlgorithms) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithms { - b.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if len(m.supportedSignatureAlgorithmsCert) > 0 { - b.AddUint16(extensionSignatureAlgorithmsCert) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { - b.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if len(m.certificateAuthorities) > 0 { - b.AddUint16(extensionCertificateAuthorities) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, ca := range m.certificateAuthorities { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(ca) - }) - } - }) - }) - } - }) - }) - - var err error - m.raw, err = b.Bytes() - return m.raw, err -} - -func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool { - *m = certificateRequestMsgTLS13{raw: data} - s := cryptobyte.String(data) - - var context, extensions cryptobyte.String - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || - !s.ReadUint16LengthPrefixed(&extensions) || - !s.Empty() { - return false - } - - for !extensions.Empty() { - var extension uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&extension) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - - switch extension { - case extensionStatusRequest: - m.ocspStapling = true - case extensionSCT: - m.scts = true - case extensionSignatureAlgorithms: - var sigAndAlgs cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { - return false - } - for !sigAndAlgs.Empty() { - var sigAndAlg uint16 - if !sigAndAlgs.ReadUint16(&sigAndAlg) { - return false - } - m.supportedSignatureAlgorithms = append( - m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) - } - case extensionSignatureAlgorithmsCert: - var sigAndAlgs cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { - return false - } - for !sigAndAlgs.Empty() { - var sigAndAlg uint16 - if !sigAndAlgs.ReadUint16(&sigAndAlg) { - return false - } - m.supportedSignatureAlgorithmsCert = append( - m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) - } - case extensionCertificateAuthorities: - var auths cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&auths) || auths.Empty() { - return false - } - for !auths.Empty() { - var ca []byte - if !readUint16LengthPrefixed(&auths, &ca) || len(ca) == 0 { - return false - } - m.certificateAuthorities = append(m.certificateAuthorities, ca) - } - default: - // Ignore unknown extensions. - continue - } - - if !extData.Empty() { - return false - } - } - - return true -} - -type certificateMsg struct { - raw []byte - certificates [][]byte -} - -func (m *certificateMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var i int - for _, slice := range m.certificates { - i += len(slice) - } - - length := 3 + 3*len(m.certificates) + i - x := make([]byte, 4+length) - x[0] = typeCertificate - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - - certificateOctets := length - 3 - x[4] = uint8(certificateOctets >> 16) - x[5] = uint8(certificateOctets >> 8) - x[6] = uint8(certificateOctets) - - y := x[7:] - for _, slice := range m.certificates { - y[0] = uint8(len(slice) >> 16) - y[1] = uint8(len(slice) >> 8) - y[2] = uint8(len(slice)) - copy(y[3:], slice) - y = y[3+len(slice):] - } - - m.raw = x - return m.raw, nil -} - -func (m *certificateMsg) unmarshal(data []byte) bool { - if len(data) < 7 { - return false - } - - m.raw = data - certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]) - if uint32(len(data)) != certsLen+7 { - return false - } - - numCerts := 0 - d := data[7:] - for certsLen > 0 { - if len(d) < 4 { - return false - } - certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) - if uint32(len(d)) < 3+certLen { - return false - } - d = d[3+certLen:] - certsLen -= 3 + certLen - numCerts++ - } - - m.certificates = make([][]byte, numCerts) - d = data[7:] - for i := 0; i < numCerts; i++ { - certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) - m.certificates[i] = d[3 : 3+certLen] - d = d[3+certLen:] - } - - return true -} - -type certificateMsgTLS13 struct { - raw []byte - certificate Certificate - ocspStapling bool - scts bool -} - -func (m *certificateMsgTLS13) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var b cryptobyte.Builder - b.AddUint8(typeCertificate) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(0) // certificate_request_context - - certificate := m.certificate - if !m.ocspStapling { - certificate.OCSPStaple = nil - } - if !m.scts { - certificate.SignedCertificateTimestamps = nil - } - marshalCertificate(b, certificate) - }) - - var err error - m.raw, err = b.Bytes() - return m.raw, err -} - -func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) { - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - for i, cert := range certificate.Certificate { - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(cert) - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if i > 0 { - // This library only supports OCSP and SCT for leaf certificates. - return - } - if certificate.OCSPStaple != nil { - b.AddUint16(extensionStatusRequest) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(statusTypeOCSP) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(certificate.OCSPStaple) - }) - }) - } - if certificate.SignedCertificateTimestamps != nil { - b.AddUint16(extensionSCT) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sct := range certificate.SignedCertificateTimestamps { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(sct) - }) - } - }) - }) - } - }) - } - }) -} - -func (m *certificateMsgTLS13) unmarshal(data []byte) bool { - *m = certificateMsgTLS13{raw: data} - s := cryptobyte.String(data) - - var context cryptobyte.String - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || - !unmarshalCertificate(&s, &m.certificate) || - !s.Empty() { - return false - } - - m.scts = m.certificate.SignedCertificateTimestamps != nil - m.ocspStapling = m.certificate.OCSPStaple != nil - - return true -} - -func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool { - var certList cryptobyte.String - if !s.ReadUint24LengthPrefixed(&certList) { - return false - } - for !certList.Empty() { - var cert []byte - var extensions cryptobyte.String - if !readUint24LengthPrefixed(&certList, &cert) || - !certList.ReadUint16LengthPrefixed(&extensions) { - return false - } - certificate.Certificate = append(certificate.Certificate, cert) - for !extensions.Empty() { - var extension uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&extension) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - if len(certificate.Certificate) > 1 { - // This library only supports OCSP and SCT for leaf certificates. - continue - } - - switch extension { - case extensionStatusRequest: - var statusType uint8 - if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP || - !readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) || - len(certificate.OCSPStaple) == 0 { - return false - } - case extensionSCT: - var sctList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { - return false - } - for !sctList.Empty() { - var sct []byte - if !readUint16LengthPrefixed(&sctList, &sct) || - len(sct) == 0 { - return false - } - certificate.SignedCertificateTimestamps = append( - certificate.SignedCertificateTimestamps, sct) - } - default: - // Ignore unknown extensions. - continue - } - - if !extData.Empty() { - return false - } - } - } - return true -} - -type serverKeyExchangeMsg struct { - raw []byte - key []byte -} - -func (m *serverKeyExchangeMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - length := len(m.key) - x := make([]byte, length+4) - x[0] = typeServerKeyExchange - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - copy(x[4:], m.key) - - m.raw = x - return x, nil -} - -func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { - m.raw = data - if len(data) < 4 { - return false - } - m.key = data[4:] - return true -} - -type certificateStatusMsg struct { - raw []byte - response []byte -} - -func (m *certificateStatusMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var b cryptobyte.Builder - b.AddUint8(typeCertificateStatus) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(statusTypeOCSP) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.response) - }) - }) - - var err error - m.raw, err = b.Bytes() - return m.raw, err -} - -func (m *certificateStatusMsg) unmarshal(data []byte) bool { - m.raw = data - s := cryptobyte.String(data) - - var statusType uint8 - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint8(&statusType) || statusType != statusTypeOCSP || - !readUint24LengthPrefixed(&s, &m.response) || - len(m.response) == 0 || !s.Empty() { - return false - } - return true -} - -type serverHelloDoneMsg struct{} - -func (m *serverHelloDoneMsg) marshal() ([]byte, error) { - x := make([]byte, 4) - x[0] = typeServerHelloDone - return x, nil -} - -func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { - return len(data) == 4 -} - -type clientKeyExchangeMsg struct { - raw []byte - ciphertext []byte -} - -func (m *clientKeyExchangeMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - length := len(m.ciphertext) - x := make([]byte, length+4) - x[0] = typeClientKeyExchange - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - copy(x[4:], m.ciphertext) - - m.raw = x - return x, nil -} - -func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { - m.raw = data - if len(data) < 4 { - return false - } - l := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) - if l != len(data)-4 { - return false - } - m.ciphertext = data[4:] - return true -} - -type finishedMsg struct { - raw []byte - verifyData []byte -} - -func (m *finishedMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var b cryptobyte.Builder - b.AddUint8(typeFinished) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.verifyData) - }) - - var err error - m.raw, err = b.Bytes() - return m.raw, err -} - -func (m *finishedMsg) unmarshal(data []byte) bool { - m.raw = data - s := cryptobyte.String(data) - return s.Skip(1) && - readUint24LengthPrefixed(&s, &m.verifyData) && - s.Empty() -} - -type certificateRequestMsg struct { - raw []byte - // hasSignatureAlgorithm indicates whether this message includes a list of - // supported signature algorithms. This change was introduced with TLS 1.2. - hasSignatureAlgorithm bool - - certificateTypes []byte - supportedSignatureAlgorithms []SignatureScheme - certificateAuthorities [][]byte -} - -func (m *certificateRequestMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - // See RFC 4346, Section 7.4.4. - length := 1 + len(m.certificateTypes) + 2 - casLength := 0 - for _, ca := range m.certificateAuthorities { - casLength += 2 + len(ca) - } - length += casLength - - if m.hasSignatureAlgorithm { - length += 2 + 2*len(m.supportedSignatureAlgorithms) - } - - x := make([]byte, 4+length) - x[0] = typeCertificateRequest - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - - x[4] = uint8(len(m.certificateTypes)) - - copy(x[5:], m.certificateTypes) - y := x[5+len(m.certificateTypes):] - - if m.hasSignatureAlgorithm { - n := len(m.supportedSignatureAlgorithms) * 2 - y[0] = uint8(n >> 8) - y[1] = uint8(n) - y = y[2:] - for _, sigAlgo := range m.supportedSignatureAlgorithms { - y[0] = uint8(sigAlgo >> 8) - y[1] = uint8(sigAlgo) - y = y[2:] - } - } - - y[0] = uint8(casLength >> 8) - y[1] = uint8(casLength) - y = y[2:] - for _, ca := range m.certificateAuthorities { - y[0] = uint8(len(ca) >> 8) - y[1] = uint8(len(ca)) - y = y[2:] - copy(y, ca) - y = y[len(ca):] - } - - m.raw = x - return m.raw, nil -} - -func (m *certificateRequestMsg) unmarshal(data []byte) bool { - m.raw = data - - if len(data) < 5 { - return false - } - - length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) - if uint32(len(data))-4 != length { - return false - } - - numCertTypes := int(data[4]) - data = data[5:] - if numCertTypes == 0 || len(data) <= numCertTypes { - return false - } - - m.certificateTypes = make([]byte, numCertTypes) - if copy(m.certificateTypes, data) != numCertTypes { - return false - } - - data = data[numCertTypes:] - - if m.hasSignatureAlgorithm { - if len(data) < 2 { - return false - } - sigAndHashLen := uint16(data[0])<<8 | uint16(data[1]) - data = data[2:] - if sigAndHashLen&1 != 0 { - return false - } - if len(data) < int(sigAndHashLen) { - return false - } - numSigAlgos := sigAndHashLen / 2 - m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos) - for i := range m.supportedSignatureAlgorithms { - m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1]) - data = data[2:] - } - } - - if len(data) < 2 { - return false - } - casLength := uint16(data[0])<<8 | uint16(data[1]) - data = data[2:] - if len(data) < int(casLength) { - return false - } - cas := make([]byte, casLength) - copy(cas, data) - data = data[casLength:] - - m.certificateAuthorities = nil - for len(cas) > 0 { - if len(cas) < 2 { - return false - } - caLen := uint16(cas[0])<<8 | uint16(cas[1]) - cas = cas[2:] - - if len(cas) < int(caLen) { - return false - } - - m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen]) - cas = cas[caLen:] - } - - return len(data) == 0 -} - -type certificateVerifyMsg struct { - raw []byte - hasSignatureAlgorithm bool // format change introduced in TLS 1.2 - signatureAlgorithm SignatureScheme - signature []byte -} - -func (m *certificateVerifyMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var b cryptobyte.Builder - b.AddUint8(typeCertificateVerify) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - if m.hasSignatureAlgorithm { - b.AddUint16(uint16(m.signatureAlgorithm)) - } - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.signature) - }) - }) - - var err error - m.raw, err = b.Bytes() - return m.raw, err -} - -func (m *certificateVerifyMsg) unmarshal(data []byte) bool { - m.raw = data - s := cryptobyte.String(data) - - if !s.Skip(4) { // message type and uint24 length field - return false - } - if m.hasSignatureAlgorithm { - if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) { - return false - } - } - return readUint16LengthPrefixed(&s, &m.signature) && s.Empty() -} - -type newSessionTicketMsg struct { - raw []byte - ticket []byte -} - -func (m *newSessionTicketMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - // See RFC 5077, Section 3.3. - ticketLen := len(m.ticket) - length := 2 + 4 + ticketLen - x := make([]byte, 4+length) - x[0] = typeNewSessionTicket - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - x[8] = uint8(ticketLen >> 8) - x[9] = uint8(ticketLen) - copy(x[10:], m.ticket) - - m.raw = x - - return m.raw, nil -} - -func (m *newSessionTicketMsg) unmarshal(data []byte) bool { - m.raw = data - - if len(data) < 10 { - return false - } - - length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) - if uint32(len(data))-4 != length { - return false - } - - ticketLen := int(data[8])<<8 + int(data[9]) - if len(data)-10 != ticketLen { - return false - } - - m.ticket = data[10:] - - return true -} - -type helloRequestMsg struct { -} - -func (*helloRequestMsg) marshal() ([]byte, error) { - return []byte{typeHelloRequest, 0, 0, 0}, nil -} - -func (*helloRequestMsg) unmarshal(data []byte) bool { - return len(data) == 4 -} - -type transcriptHash interface { - Write([]byte) (int, error) -} - -// transcriptMsg is a helper used to marshal and hash messages which typically -// are not written to the wire, and as such aren't hashed during Conn.writeRecord. -func transcriptMsg(msg handshakeMessage, h transcriptHash) error { - data, err := msg.marshal() - if err != nil { - return err - } - h.Write(data) - return nil -} diff --git a/pkg/tls/handshake_messages_test.go b/pkg/tls/handshake_messages_test.go deleted file mode 100644 index 206e2fb02..000000000 --- a/pkg/tls/handshake_messages_test.go +++ /dev/null @@ -1,495 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "bytes" - "encoding/hex" - "math/rand" - "reflect" - "strings" - "testing" - "testing/quick" - "time" -) - -var tests = []any{ - &clientHelloMsg{}, - &serverHelloMsg{}, - &finishedMsg{}, - - &certificateMsg{}, - &certificateRequestMsg{}, - &certificateVerifyMsg{ - hasSignatureAlgorithm: true, - }, - &certificateStatusMsg{}, - &clientKeyExchangeMsg{}, - &newSessionTicketMsg{}, - &sessionState{}, - &sessionStateTLS13{}, - &encryptedExtensionsMsg{}, - &endOfEarlyDataMsg{}, - &keyUpdateMsg{}, - &newSessionTicketMsgTLS13{}, - &certificateRequestMsgTLS13{}, - &certificateMsgTLS13{}, -} - -func mustMarshal(t *testing.T, msg handshakeMessage) []byte { - t.Helper() - b, err := msg.marshal() - if err != nil { - t.Fatal(err) - } - return b -} - -func TestMarshalUnmarshal(t *testing.T) { - rand := rand.New(rand.NewSource(time.Now().UnixNano())) - - for i, iface := range tests { - ty := reflect.ValueOf(iface).Type() - - n := 100 - if testing.Short() { - n = 5 - } - for j := 0; j < n; j++ { - v, ok := quick.Value(ty, rand) - if !ok { - t.Errorf("#%d: failed to create value", i) - break - } - - m1 := v.Interface().(handshakeMessage) - marshaled := mustMarshal(t, m1) - m2 := iface.(handshakeMessage) - if !m2.unmarshal(marshaled) { - t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) - break - } - m2.marshal() // to fill any marshal cache in the message - - if !reflect.DeepEqual(m1, m2) { - t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled) - break - } - - if i >= 3 { - // The first three message types (ClientHello, - // ServerHello and Finished) are allowed to - // have parsable prefixes because the extension - // data is optional and the length of the - // Finished varies across versions. - for j := 0; j < len(marshaled); j++ { - if m2.unmarshal(marshaled[0:j]) { - t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1) - break - } - } - } - } - } -} - -func TestFuzz(t *testing.T) { - rand := rand.New(rand.NewSource(0)) - for _, iface := range tests { - m := iface.(handshakeMessage) - - for j := 0; j < 1000; j++ { - len := rand.Intn(100) - bytes := randomBytes(len, rand) - // This just looks for crashes due to bounds errors etc. - m.unmarshal(bytes) - } - } -} - -func randomBytes(n int, rand *rand.Rand) []byte { - r := make([]byte, n) - if _, err := rand.Read(r); err != nil { - panic("rand.Read failed: " + err.Error()) - } - return r -} - -func randomString(n int, rand *rand.Rand) string { - b := randomBytes(n, rand) - return string(b) -} - -func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { - m := &clientHelloMsg{} - m.vers = uint16(rand.Intn(65536)) - m.random = randomBytes(32, rand) - m.sessionId = randomBytes(rand.Intn(32), rand) - m.cipherSuites = make([]uint16, rand.Intn(63)+1) - for i := 0; i < len(m.cipherSuites); i++ { - cs := uint16(rand.Int31()) - if cs == scsvRenegotiation { - cs += 1 - } - m.cipherSuites[i] = cs - } - m.compressionMethods = randomBytes(rand.Intn(63)+1, rand) - if rand.Intn(10) > 5 { - m.serverName = randomString(rand.Intn(255), rand) - for strings.HasSuffix(m.serverName, ".") { - m.serverName = m.serverName[:len(m.serverName)-1] - } - } - m.ocspStapling = rand.Intn(10) > 5 - m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) - m.supportedCurves = make([]CurveID, rand.Intn(5)+1) - for i := range m.supportedCurves { - m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1) - } - if rand.Intn(10) > 5 { - m.ticketSupported = true - if rand.Intn(10) > 5 { - m.sessionTicket = randomBytes(rand.Intn(300), rand) - } else { - m.sessionTicket = make([]byte, 0) - } - } - if rand.Intn(10) > 5 { - m.supportedSignatureAlgorithms = supportedSignatureAlgorithms() - } - if rand.Intn(10) > 5 { - m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms() - } - for i := 0; i < rand.Intn(5); i++ { - m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand)) - } - if rand.Intn(10) > 5 { - m.scts = true - } - if rand.Intn(10) > 5 { - m.secureRenegotiationSupported = true - m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) - } - for i := 0; i < rand.Intn(5); i++ { - m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1)) - } - if rand.Intn(10) > 5 { - m.cookie = randomBytes(rand.Intn(500)+1, rand) - } - for i := 0; i < rand.Intn(5); i++ { - var ks keyShare - ks.group = CurveID(rand.Intn(30000) + 1) - ks.data = randomBytes(rand.Intn(200)+1, rand) - m.keyShares = append(m.keyShares, ks) - } - switch rand.Intn(3) { - case 1: - m.pskModes = []uint8{pskModeDHE} - case 2: - m.pskModes = []uint8{pskModeDHE, pskModePlain} - } - for i := 0; i < rand.Intn(5); i++ { - var psk pskIdentity - psk.obfuscatedTicketAge = uint32(rand.Intn(500000)) - psk.label = randomBytes(rand.Intn(500)+1, rand) - m.pskIdentities = append(m.pskIdentities, psk) - m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand)) - } - if rand.Intn(10) > 5 { - m.earlyData = true - } - - return reflect.ValueOf(m) -} - -func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { - m := &serverHelloMsg{} - m.vers = uint16(rand.Intn(65536)) - m.random = randomBytes(32, rand) - m.sessionId = randomBytes(rand.Intn(32), rand) - m.cipherSuite = uint16(rand.Int31()) - m.compressionMethod = uint8(rand.Intn(256)) - m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) - - if rand.Intn(10) > 5 { - m.ocspStapling = true - } - if rand.Intn(10) > 5 { - m.ticketSupported = true - } - if rand.Intn(10) > 5 { - m.alpnProtocol = randomString(rand.Intn(32)+1, rand) - } - - for i := 0; i < rand.Intn(4); i++ { - m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand)) - } - - if rand.Intn(10) > 5 { - m.secureRenegotiationSupported = true - m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) - } - if rand.Intn(10) > 5 { - m.supportedVersion = uint16(rand.Intn(0xffff) + 1) - } - if rand.Intn(10) > 5 { - m.cookie = randomBytes(rand.Intn(500)+1, rand) - } - if rand.Intn(10) > 5 { - for i := 0; i < rand.Intn(5); i++ { - m.serverShare.group = CurveID(rand.Intn(30000) + 1) - m.serverShare.data = randomBytes(rand.Intn(200)+1, rand) - } - } else if rand.Intn(10) > 5 { - m.selectedGroup = CurveID(rand.Intn(30000) + 1) - } - if rand.Intn(10) > 5 { - m.selectedIdentityPresent = true - m.selectedIdentity = uint16(rand.Intn(0xffff)) - } - - return reflect.ValueOf(m) -} - -func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value { - m := &encryptedExtensionsMsg{} - - if rand.Intn(10) > 5 { - m.alpnProtocol = randomString(rand.Intn(32)+1, rand) - } - - return reflect.ValueOf(m) -} - -func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { - m := &certificateMsg{} - numCerts := rand.Intn(20) - m.certificates = make([][]byte, numCerts) - for i := 0; i < numCerts; i++ { - m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) - } - return reflect.ValueOf(m) -} - -func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { - m := &certificateRequestMsg{} - m.certificateTypes = randomBytes(rand.Intn(5)+1, rand) - for i := 0; i < rand.Intn(100); i++ { - m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand)) - } - return reflect.ValueOf(m) -} - -func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { - m := &certificateVerifyMsg{} - m.hasSignatureAlgorithm = true - m.signatureAlgorithm = SignatureScheme(rand.Intn(30000)) - m.signature = randomBytes(rand.Intn(15)+1, rand) - return reflect.ValueOf(m) -} - -func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { - m := &certificateStatusMsg{} - m.response = randomBytes(rand.Intn(10)+1, rand) - return reflect.ValueOf(m) -} - -func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { - m := &clientKeyExchangeMsg{} - m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) - return reflect.ValueOf(m) -} - -func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { - m := &finishedMsg{} - m.verifyData = randomBytes(12, rand) - return reflect.ValueOf(m) -} - -func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value { - m := &newSessionTicketMsg{} - m.ticket = randomBytes(rand.Intn(4), rand) - return reflect.ValueOf(m) -} - -func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value { - s := &sessionState{} - s.vers = uint16(rand.Intn(10000)) - s.cipherSuite = uint16(rand.Intn(10000)) - s.masterSecret = randomBytes(rand.Intn(100)+1, rand) - s.createdAt = uint64(rand.Int63()) - for i := 0; i < rand.Intn(20); i++ { - s.certificates = append(s.certificates, randomBytes(rand.Intn(500)+1, rand)) - } - return reflect.ValueOf(s) -} - -func (*sessionStateTLS13) Generate(rand *rand.Rand, size int) reflect.Value { - s := &sessionStateTLS13{} - s.cipherSuite = uint16(rand.Intn(10000)) - s.resumptionSecret = randomBytes(rand.Intn(100)+1, rand) - s.createdAt = uint64(rand.Int63()) - for i := 0; i < rand.Intn(2)+1; i++ { - s.certificate.Certificate = append( - s.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) - } - if rand.Intn(10) > 5 { - s.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) - } - if rand.Intn(10) > 5 { - for i := 0; i < rand.Intn(2)+1; i++ { - s.certificate.SignedCertificateTimestamps = append( - s.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) - } - } - return reflect.ValueOf(s) -} - -func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value { - m := &endOfEarlyDataMsg{} - return reflect.ValueOf(m) -} - -func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value { - m := &keyUpdateMsg{} - m.updateRequested = rand.Intn(10) > 5 - return reflect.ValueOf(m) -} - -func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { - m := &newSessionTicketMsgTLS13{} - m.lifetime = uint32(rand.Intn(500000)) - m.ageAdd = uint32(rand.Intn(500000)) - m.nonce = randomBytes(rand.Intn(100), rand) - m.label = randomBytes(rand.Intn(1000), rand) - if rand.Intn(10) > 5 { - m.maxEarlyData = uint32(rand.Intn(500000)) - } - return reflect.ValueOf(m) -} - -func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { - m := &certificateRequestMsgTLS13{} - if rand.Intn(10) > 5 { - m.ocspStapling = true - } - if rand.Intn(10) > 5 { - m.scts = true - } - if rand.Intn(10) > 5 { - m.supportedSignatureAlgorithms = supportedSignatureAlgorithms() - } - if rand.Intn(10) > 5 { - m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms() - } - if rand.Intn(10) > 5 { - m.certificateAuthorities = make([][]byte, 3) - for i := 0; i < 3; i++ { - m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand) - } - } - return reflect.ValueOf(m) -} - -func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { - m := &certificateMsgTLS13{} - for i := 0; i < rand.Intn(2)+1; i++ { - m.certificate.Certificate = append( - m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) - } - if rand.Intn(10) > 5 { - m.ocspStapling = true - m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) - } - if rand.Intn(10) > 5 { - m.scts = true - for i := 0; i < rand.Intn(2)+1; i++ { - m.certificate.SignedCertificateTimestamps = append( - m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) - } - } - return reflect.ValueOf(m) -} - -func TestRejectEmptySCTList(t *testing.T) { - // RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid. - - var random [32]byte - sct := []byte{0x42, 0x42, 0x42, 0x42} - serverHello := &serverHelloMsg{ - vers: VersionTLS12, - random: random[:], - scts: [][]byte{sct}, - } - serverHelloBytes := mustMarshal(t, serverHello) - - var serverHelloCopy serverHelloMsg - if !serverHelloCopy.unmarshal(serverHelloBytes) { - t.Fatal("Failed to unmarshal initial message") - } - - // Change serverHelloBytes so that the SCT list is empty - i := bytes.Index(serverHelloBytes, sct) - if i < 0 { - t.Fatal("Cannot find SCT in ServerHello") - } - - var serverHelloEmptySCT []byte - serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...) - // Append the extension length and SCT list length for an empty list. - serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...) - serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...) - - // Update the handshake message length. - serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16) - serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8) - serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4) - - // Update the extensions length - serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8) - serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44)) - - if serverHelloCopy.unmarshal(serverHelloEmptySCT) { - t.Fatal("Unmarshaled ServerHello with empty SCT list") - } -} - -func TestRejectEmptySCT(t *testing.T) { - // Not only must the SCT list be non-empty, but the SCT elements must - // not be zero length. - - var random [32]byte - serverHello := &serverHelloMsg{ - vers: VersionTLS12, - random: random[:], - scts: [][]byte{nil}, - } - serverHelloBytes := mustMarshal(t, serverHello) - - var serverHelloCopy serverHelloMsg - if serverHelloCopy.unmarshal(serverHelloBytes) { - t.Fatal("Unmarshaled ServerHello with zero-length SCT") - } -} - -func TestRejectDuplicateExtensions(t *testing.T) { - clientHelloBytes, err := hex.DecodeString("010000440303000000000000000000000000000000000000000000000000000000000000000000000000001c0000000a000800000568656c6c6f0000000a000800000568656c6c6f") - if err != nil { - t.Fatalf("failed to decode test ClientHello: %s", err) - } - var clientHelloCopy clientHelloMsg - if clientHelloCopy.unmarshal(clientHelloBytes) { - t.Error("Unmarshaled ClientHello with duplicate extensions") - } - - serverHelloBytes, err := hex.DecodeString("02000030030300000000000000000000000000000000000000000000000000000000000000000000000000080005000000050000") - if err != nil { - t.Fatalf("failed to decode test ServerHello: %s", err) - } - var serverHelloCopy serverHelloMsg - if serverHelloCopy.unmarshal(serverHelloBytes) { - t.Fatal("Unmarshaled ServerHello with duplicate extensions") - } -} diff --git a/pkg/tls/handshake_server.go b/pkg/tls/handshake_server.go deleted file mode 100644 index 0198aca3f..000000000 --- a/pkg/tls/handshake_server.go +++ /dev/null @@ -1,934 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "context" - "crypto" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/rsa" - "crypto/subtle" - "crypto/x509" - "errors" - "fmt" - "hash" - "io" - "strconv" - "time" -) - -// serverHandshakeState contains details of a server handshake in progress. -// It's discarded once the handshake has completed. -type serverHandshakeState struct { - c *Conn - ctx context.Context - clientHello *clientHelloMsg - hello *serverHelloMsg - suite *cipherSuite - ecdheOk bool - ecSignOk bool - rsaDecryptOk bool - rsaSignOk bool - sessionState *sessionState - finishedHash finishedHash - masterSecret []byte - cert *Certificate - - keyAgreement keyAgreement - certReq *certificateRequestMsg -} - -// serverHandshake performs a TLS handshake as a server. -func (c *Conn) serverHandshake(ctx context.Context) error { - // If this is the first server handshake, we generate a random key to - // encrypt the tickets with. - //gnet不能进行阻塞二次读取,所以会分几条消息重复执行此方法,status也会分很多个状态 - if c.hs == nil { - //首次执行要初始化对象 - clientHello, err := c.readClientHello(ctx) - if err != nil { - return err - } - - if c.vers == VersionTLS13 { - c.hs = &serverHandshakeStateTLS13{ - c: c, - ctx: ctx, - clientHello: clientHello, - } - - } else { - c.hs = &serverHandshakeState{ - c: c, - ctx: ctx, - clientHello: clientHello, - } - } - } - return c.hs.handshake() -} - -func (hs *serverHandshakeState) handshake() error { - c := hs.c - if c.handshakeStatus == 0 { - if err := hs.processClientHello(); err != nil { - return err - } - - // For an overview of TLS handshaking, see RFC 5246, Section 7.3. - c.buffering = true - } - - if hs.checkForResumption() { - switch c.handshakeStatus { - case 0: - // The client has included a session ticket and so we do an abbreviated handshake. - c.didResume = true - if err := hs.doResumeHandshake(); err != nil { - return err - } - if err := hs.establishKeys(); err != nil { - return err - } - if err := hs.sendSessionTicket(); err != nil { - return err - } - if err := hs.sendFinished(c.serverFinished[:]); err != nil { - return err - } - if _, err := c.flush(); err != nil { - return err - } - c.handshakeStatus = 1 - return nil - case 1: - c.clientFinishedIsFirst = false - if err := hs.readFinished(nil); err != nil { - return err - } - - default: - return errors.New("错误的status状态" + strconv.Itoa(int(c.handshakeStatus))) - } - } else { - // The client didn't include a session ticket, or it wasn't - // valid so we do a full handshake. - switch c.handshakeStatus { - case 0: - if err := hs.pickCipherSuite(); err != nil { - return err - } - if err := hs.doFullHandshakeStep1(); err != nil { - return err - } - c.handshakeStatus = 3 - return nil - case 3: - if err := hs.doFullHandshakeStep2(); err != nil { - return err - } - if err := hs.establishKeys(); err != nil { - return err - } - c.handshakeStatus = 4 - if c.rawInput.Len() < 5 { - return nil - } - fallthrough - case 4: - if err := hs.readFinished(c.clientFinished[:]); err != nil { - return err - } - c.clientFinishedIsFirst = true - c.buffering = true - if err := hs.sendSessionTicket(); err != nil { - return err - } - if err := hs.sendFinished(nil); err != nil { - return err - } - if _, err := c.flush(); err != nil { - return err - } - } - - } - - c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random) - c.handshakeStatus = 255 - // Enable kernel TLS if possible - if err := c.enableKernelTLS(c.cipherSuite, c.in.key, c.out.key, c.in.iv, c.out.iv); err != nil { - return err - } - - return nil -} - -// readClientHello reads a ClientHello message and selects the protocol version. -func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) { - // clientHelloMsg is included in the transcript, but we haven't initialized - // it yet. The respective handshake functions will record it themselves. - msg, err := c.readHandshake(nil) - if err != nil { - return nil, err - } - clientHello, ok := msg.(*clientHelloMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return nil, unexpectedMessageError(clientHello, msg) - } - - var configForClient *Config - originalConfig := c.config - if c.config.GetConfigForClient != nil { - chi := clientHelloInfo(ctx, c, clientHello) - if configForClient, err = c.config.GetConfigForClient(chi); err != nil { - c.sendAlert(alertInternalError) - return nil, err - } else if configForClient != nil { - c.config = configForClient - } - } - c.ticketKeys = originalConfig.ticketKeys(configForClient) - - clientVersions := clientHello.supportedVersions - if len(clientHello.supportedVersions) == 0 { - clientVersions = supportedVersionsFromMax(clientHello.vers) - } - c.vers, ok = c.config.mutualVersion(roleServer, clientVersions) - if !ok { - c.sendAlert(alertProtocolVersion) - return nil, fmt.Errorf("tls: client offered only unsupported versions: %x", clientVersions) - } - c.haveVers = true - c.in.version = c.vers - c.out.version = c.vers - - return clientHello, nil -} - -func (hs *serverHandshakeState) processClientHello() error { - c := hs.c - - hs.hello = new(serverHelloMsg) - hs.hello.vers = c.vers - - foundCompression := false - // We only support null compression, so check that the client offered it. - for _, compression := range hs.clientHello.compressionMethods { - if compression == compressionNone { - foundCompression = true - break - } - } - - if !foundCompression { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: client does not support uncompressed connections") - } - - hs.hello.random = make([]byte, 32) - serverRandom := hs.hello.random - // Downgrade protection canaries. See RFC 8446, Section 4.1.3. - maxVers := c.config.maxSupportedVersion(roleServer) - if maxVers >= VersionTLS12 && c.vers < maxVers || testingOnlyForceDowngradeCanary { - if c.vers == VersionTLS12 { - copy(serverRandom[24:], downgradeCanaryTLS12) - } else { - copy(serverRandom[24:], downgradeCanaryTLS11) - } - serverRandom = serverRandom[:24] - } - _, err := io.ReadFull(c.config.rand(), serverRandom) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - if len(hs.clientHello.secureRenegotiation) != 0 { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: initial handshake had non-empty renegotiation extension") - } - - hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported - hs.hello.compressionMethod = compressionNone - if len(hs.clientHello.serverName) > 0 { - c.serverName = hs.clientHello.serverName - } - - selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols) - if err != nil { - c.sendAlert(alertNoApplicationProtocol) - return err - } - hs.hello.alpnProtocol = selectedProto - c.clientProtocol = selectedProto - - hs.cert, err = c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello)) - if err != nil { - if err == errNoCertificates { - c.sendAlert(alertUnrecognizedName) - } else { - c.sendAlert(alertInternalError) - } - return err - } - if hs.clientHello.scts { - hs.hello.scts = hs.cert.SignedCertificateTimestamps - } - - hs.ecdheOk = supportsECDHE(c.config, hs.clientHello.supportedCurves, hs.clientHello.supportedPoints) - - if hs.ecdheOk && len(hs.clientHello.supportedPoints) > 0 { - // Although omitting the ec_point_formats extension is permitted, some - // old OpenSSL version will refuse to handshake if not present. - // - // Per RFC 4492, section 5.1.2, implementations MUST support the - // uncompressed point format. See golang.org/issue/31943. - hs.hello.supportedPoints = []uint8{pointFormatUncompressed} - } - - if priv, ok := hs.cert.PrivateKey.(crypto.Signer); ok { - switch priv.Public().(type) { - case *ecdsa.PublicKey: - hs.ecSignOk = true - case ed25519.PublicKey: - hs.ecSignOk = true - case *rsa.PublicKey: - hs.rsaSignOk = true - default: - c.sendAlert(alertInternalError) - return fmt.Errorf("tls: unsupported signing key type (%T)", priv.Public()) - } - } - if priv, ok := hs.cert.PrivateKey.(crypto.Decrypter); ok { - switch priv.Public().(type) { - case *rsa.PublicKey: - hs.rsaDecryptOk = true - default: - c.sendAlert(alertInternalError) - return fmt.Errorf("tls: unsupported decryption key type (%T)", priv.Public()) - } - } - - return nil -} - -// negotiateALPN picks a shared ALPN protocol that both sides support in server -// preference order. If ALPN is not configured or the peer doesn't support it, -// it returns "" and no error. -func negotiateALPN(serverProtos, clientProtos []string) (string, error) { - if len(serverProtos) == 0 || len(clientProtos) == 0 { - return "", nil - } - var http11fallback bool - for _, s := range serverProtos { - for _, c := range clientProtos { - if s == c { - return s, nil - } - if s == "h2" && c == "http/1.1" { - http11fallback = true - } - } - } - // As a special case, let http/1.1 clients connect to h2 servers as if they - // didn't support ALPN. We used not to enforce protocol overlap, so over - // time a number of HTTP servers were configured with only "h2", but - // expected to accept connections from "http/1.1" clients. See Issue 46310. - if http11fallback { - return "", nil - } - return "", fmt.Errorf("tls: client requested unsupported application protocols (%s)", clientProtos) -} - -// supportsECDHE returns whether ECDHE key exchanges can be used with this -// pre-TLS 1.3 client. -func supportsECDHE(c *Config, supportedCurves []CurveID, supportedPoints []uint8) bool { - supportsCurve := false - for _, curve := range supportedCurves { - if c.supportsCurve(curve) { - supportsCurve = true - break - } - } - - supportsPointFormat := false - for _, pointFormat := range supportedPoints { - if pointFormat == pointFormatUncompressed { - supportsPointFormat = true - break - } - } - // Per RFC 8422, Section 5.1.2, if the Supported Point Formats extension is - // missing, uncompressed points are supported. If supportedPoints is empty, - // the extension must be missing, as an empty extension body is rejected by - // the parser. See https://go.dev/issue/49126. - if len(supportedPoints) == 0 { - supportsPointFormat = true - } - - return supportsCurve && supportsPointFormat -} - -func (hs *serverHandshakeState) pickCipherSuite() error { - c := hs.c - - preferenceOrder := cipherSuitesPreferenceOrder - if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) { - preferenceOrder = cipherSuitesPreferenceOrderNoAES - } - - configCipherSuites := c.config.cipherSuites() - preferenceList := make([]uint16, 0, len(configCipherSuites)) - for _, suiteID := range preferenceOrder { - for _, id := range configCipherSuites { - if id == suiteID { - preferenceList = append(preferenceList, id) - break - } - } - } - - hs.suite = selectCipherSuite(preferenceList, hs.clientHello.cipherSuites, hs.cipherSuiteOk) - if hs.suite == nil { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: no cipher suite supported by both client and server") - } - c.cipherSuite = hs.suite.id - - for _, id := range hs.clientHello.cipherSuites { - if id == TLS_FALLBACK_SCSV { - // The client is doing a fallback connection. See RFC 7507. - if hs.clientHello.vers < c.config.maxSupportedVersion(roleServer) { - c.sendAlert(alertInappropriateFallback) - return errors.New("tls: client using inappropriate protocol fallback") - } - break - } - } - - return nil -} - -func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool { - if c.flags&suiteECDHE != 0 { - if !hs.ecdheOk { - return false - } - if c.flags&suiteECSign != 0 { - if !hs.ecSignOk { - return false - } - } else if !hs.rsaSignOk { - return false - } - } else if !hs.rsaDecryptOk { - return false - } - if hs.c.vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { - return false - } - return true -} - -// checkForResumption reports whether we should perform resumption on this connection. -func (hs *serverHandshakeState) checkForResumption() bool { - c := hs.c - - if c.config.SessionTicketsDisabled { - return false - } - - plaintext, usedOldKey := c.decryptTicket(hs.clientHello.sessionTicket) - if plaintext == nil { - return false - } - hs.sessionState = &sessionState{usedOldKey: usedOldKey} - ok := hs.sessionState.unmarshal(plaintext) - if !ok { - return false - } - - createdAt := time.Unix(int64(hs.sessionState.createdAt), 0) - if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { - return false - } - - // Never resume a session for a different TLS version. - if c.vers != hs.sessionState.vers { - return false - } - - cipherSuiteOk := false - // Check that the client is still offering the ciphersuite in the session. - for _, id := range hs.clientHello.cipherSuites { - if id == hs.sessionState.cipherSuite { - cipherSuiteOk = true - break - } - } - if !cipherSuiteOk { - return false - } - - // Check that we also support the ciphersuite from the session. - hs.suite = selectCipherSuite([]uint16{hs.sessionState.cipherSuite}, - c.config.cipherSuites(), hs.cipherSuiteOk) - if hs.suite == nil { - return false - } - - sessionHasClientCerts := len(hs.sessionState.certificates) != 0 - needClientCerts := requiresClientCert(c.config.ClientAuth) - if needClientCerts && !sessionHasClientCerts { - return false - } - if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { - return false - } - - return true -} - -func (hs *serverHandshakeState) doResumeHandshake() error { - c := hs.c - - hs.hello.cipherSuite = hs.suite.id - c.cipherSuite = hs.suite.id - // We echo the client's session ID in the ServerHello to let it know - // that we're doing a resumption. - hs.hello.sessionId = hs.clientHello.sessionId - hs.hello.ticketSupported = hs.sessionState.usedOldKey - hs.finishedHash = newFinishedHash(c.vers, hs.suite) - hs.finishedHash.discardHandshakeBuffer() - if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil { - return err - } - if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil { - return err - } - - if err := c.processCertsFromClient(Certificate{ - Certificate: hs.sessionState.certificates, - }); err != nil { - return err - } - - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - hs.masterSecret = hs.sessionState.masterSecret - - return nil -} - -func (hs *serverHandshakeState) doFullHandshakeStep1() error { - c := hs.c - - if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 { - hs.hello.ocspStapling = true - } - - hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled - hs.hello.cipherSuite = hs.suite.id - - hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite) - if c.config.ClientAuth == NoClientCert { - // No need to keep a full record of the handshake if client - // certificates won't be used. - hs.finishedHash.discardHandshakeBuffer() - } - if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil { - return err - } - if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil { - return err - } - - certMsg := new(certificateMsg) - certMsg.certificates = hs.cert.Certificate - if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil { - return err - } - - if hs.hello.ocspStapling { - certStatus := new(certificateStatusMsg) - certStatus.response = hs.cert.OCSPStaple - if _, err := hs.c.writeHandshakeRecord(certStatus, &hs.finishedHash); err != nil { - return err - } - } - - hs.keyAgreement = hs.suite.ka(c.vers) - skx, err := hs.keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello) - if err != nil { - c.sendAlert(alertHandshakeFailure) - return err - } - if skx != nil { - if _, err := hs.c.writeHandshakeRecord(skx, &hs.finishedHash); err != nil { - return err - } - } - - if c.config.ClientAuth >= RequestClientCert { - // Request a client certificate - hs.certReq = new(certificateRequestMsg) - hs.certReq.certificateTypes = []byte{ - byte(certTypeRSASign), - byte(certTypeECDSASign), - } - if c.vers >= VersionTLS12 { - hs.certReq.hasSignatureAlgorithm = true - hs.certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms() - } - - // An empty list of certificateAuthorities signals to - // the client that it may send any certificate in response - // to our request. When we know the CAs we trust, then - // we can send them down, so that the client can choose - // an appropriate certificate to give to us. - if c.config.ClientCAs != nil { - hs.certReq.certificateAuthorities = c.config.ClientCAs.Subjects() - } - if _, err := hs.c.writeHandshakeRecord(hs.certReq, &hs.finishedHash); err != nil { - return err - } - } - - helloDone := new(serverHelloDoneMsg) - if _, err := hs.c.writeHandshakeRecord(helloDone, &hs.finishedHash); err != nil { - return err - } - - _, err = c.flush() - return err - -} -func (hs *serverHandshakeState) doFullHandshakeStep2() error { - c := hs.c - - var pub crypto.PublicKey // public key for client auth, if any - - msg, err := c.readHandshake(&hs.finishedHash) - if err != nil { - return err - } - - // If we requested a client certificate, then the client must send a - // certificate message, even if it's empty. - if c.config.ClientAuth >= RequestClientCert { - certMsg, ok := msg.(*certificateMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certMsg, msg) - } - - if err := c.processCertsFromClient(Certificate{ - Certificate: certMsg.certificates, - }); err != nil { - return err - } - if len(certMsg.certificates) != 0 { - pub = c.peerCertificates[0].PublicKey - } - - msg, err = c.readHandshake(&hs.finishedHash) - if err != nil { - return err - } - } - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - // Get client key exchange - ckx, ok := msg.(*clientKeyExchangeMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(ckx, msg) - } - - preMasterSecret, err := hs.keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers) - if err != nil { - c.sendAlert(alertHandshakeFailure) - return err - } - hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random) - if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.clientHello.random, hs.masterSecret); err != nil { - c.sendAlert(alertInternalError) - return err - } - - // If we received a client cert in response to our certificate request message, - // the client will send us a certificateVerifyMsg immediately after the - // clientKeyExchangeMsg. This message is a digest of all preceding - // handshake-layer messages that is signed using the private key corresponding - // to the client's certificate. This allows us to verify that the client is in - // possession of the private key of the certificate. - if len(c.peerCertificates) > 0 { - // certificateVerifyMsg is included in the transcript, but not until - // after we verify the handshake signature, since the state before - // this message was sent is used. - msg, err = c.readHandshake(nil) - if err != nil { - return err - } - certVerify, ok := msg.(*certificateVerifyMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certVerify, msg) - } - - var sigType uint8 - var sigHash crypto.Hash - if c.vers >= VersionTLS12 { - if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, hs.certReq.supportedSignatureAlgorithms) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client certificate used with invalid signature algorithm") - } - sigType, sigHash, err = typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) - if err != nil { - return c.sendAlert(alertInternalError) - } - } else { - sigType, sigHash, err = legacyTypeAndHashFromPublicKey(pub) - if err != nil { - c.sendAlert(alertIllegalParameter) - return err - } - } - - signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash) - if err := verifyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid signature by the client certificate: " + err.Error()) - } - - if err := transcriptMsg(certVerify, &hs.finishedHash); err != nil { - return err - } - } - - hs.finishedHash.discardHandshakeBuffer() - - return nil -} - -func (hs *serverHandshakeState) establishKeys() error { - c := hs.c - - clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := - keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) - - var clientCipher, serverCipher interface{} - var clientHash, serverHash hash.Hash - - if hs.suite.aead == nil { - clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */) - clientHash = hs.suite.mac(clientMAC) - serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */) - serverHash = hs.suite.mac(serverMAC) - } else { - clientCipher = hs.suite.aead(clientKey, clientIV) - serverCipher = hs.suite.aead(serverKey, serverIV) - } - - c.in.key, c.in.iv = clientKey, clientIV - c.out.key, c.out.iv = serverKey, serverIV - c.in.prepareCipherSpec(c.vers, clientCipher, clientHash) - c.out.prepareCipherSpec(c.vers, serverCipher, serverHash) - - return nil -} - -func (hs *serverHandshakeState) readFinished(out []byte) error { - c := hs.c - - if err := c.readChangeCipherSpec(); err != nil { - return err - } - - // finishedMsg is included in the transcript, but not until after we - // check the client version, since the state before this message was - // sent is used during verification. - msg, err := c.readHandshake(nil) - if err != nil { - return err - } - clientFinished, ok := msg.(*finishedMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(clientFinished, msg) - } - - verify := hs.finishedHash.clientSum(hs.masterSecret) - if len(verify) != len(clientFinished.verifyData) || - subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: client's Finished message is incorrect") - } - - if err := transcriptMsg(clientFinished, &hs.finishedHash); err != nil { - return err - } - - copy(out, verify) - return nil -} - -func (hs *serverHandshakeState) sendSessionTicket() error { - // ticketSupported is set in a resumption handshake if the - // ticket from the client was encrypted with an old session - // ticket key and thus a refreshed ticket should be sent. - if !hs.hello.ticketSupported { - return nil - } - - c := hs.c - m := new(newSessionTicketMsg) - - createdAt := uint64(c.config.time().Unix()) - if hs.sessionState != nil { - // If this is re-wrapping an old key, then keep - // the original time it was created. - createdAt = hs.sessionState.createdAt - } - - var certsFromClient [][]byte - for _, cert := range c.peerCertificates { - certsFromClient = append(certsFromClient, cert.Raw) - } - state := sessionState{ - vers: c.vers, - cipherSuite: hs.suite.id, - createdAt: createdAt, - masterSecret: hs.masterSecret, - certificates: certsFromClient, - } - stateBytes, err := state.marshal() - if err != nil { - return err - } - m.ticket, err = c.encryptTicket(stateBytes) - if err != nil { - return err - } - - if _, err := hs.c.writeHandshakeRecord(m, &hs.finishedHash); err != nil { - return err - } - - return nil -} - -func (hs *serverHandshakeState) sendFinished(out []byte) error { - c := hs.c - - if err := c.writeChangeCipherRecord(); err != nil { - return err - } - - finished := new(finishedMsg) - finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) - if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil { - return err - } - - copy(out, finished.verifyData) - - return nil -} - -// processCertsFromClient takes a chain of client certificates either from a -// Certificates message or from a sessionState and verifies them. It returns -// the public key of the leaf certificate. -func (c *Conn) processCertsFromClient(certificate Certificate) error { - certificates := certificate.Certificate - certs := make([]*x509.Certificate, len(certificates)) - var err error - for i, asn1Data := range certificates { - if certs[i], err = x509.ParseCertificate(asn1Data); err != nil { - c.sendAlert(alertBadCertificate) - return errors.New("tls: failed to parse client certificate: " + err.Error()) - } - } - - if len(certs) == 0 && requiresClientCert(c.config.ClientAuth) { - c.sendAlert(alertBadCertificate) - return errors.New("tls: client didn't provide a certificate") - } - - if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 { - opts := x509.VerifyOptions{ - Roots: c.config.ClientCAs, - CurrentTime: c.config.time(), - Intermediates: x509.NewCertPool(), - KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, - } - - for _, cert := range certs[1:] { - opts.Intermediates.AddCert(cert) - } - - chains, err := certs[0].Verify(opts) - if err != nil { - c.sendAlert(alertBadCertificate) - return &CertificateVerificationError{UnverifiedCertificates: certs, Err: err} - } - - c.verifiedChains = chains - } - - c.peerCertificates = certs - c.ocspResponse = certificate.OCSPStaple - c.scts = certificate.SignedCertificateTimestamps - - if len(certs) > 0 { - switch certs[0].PublicKey.(type) { - case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey: - default: - c.sendAlert(alertUnsupportedCertificate) - return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey) - } - } - - if c.config.VerifyPeerCertificate != nil { - if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - return nil -} - -func clientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { - supportedVersions := clientHello.supportedVersions - if len(clientHello.supportedVersions) == 0 { - supportedVersions = supportedVersionsFromMax(clientHello.vers) - } - - return &ClientHelloInfo{ - CipherSuites: clientHello.cipherSuites, - ServerName: clientHello.serverName, - SupportedCurves: clientHello.supportedCurves, - SupportedPoints: clientHello.supportedPoints, - SignatureSchemes: clientHello.supportedSignatureAlgorithms, - SupportedProtos: clientHello.alpnProtocols, - SupportedVersions: supportedVersions, - Conn: c.conn, - config: c.config, - } -} diff --git a/pkg/tls/handshake_server_tls13.go b/pkg/tls/handshake_server_tls13.go deleted file mode 100644 index 0d838824e..000000000 --- a/pkg/tls/handshake_server_tls13.go +++ /dev/null @@ -1,902 +0,0 @@ -// Copyright 2018 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "bytes" - "context" - "crypto" - "crypto/hmac" - "crypto/rsa" - "encoding/binary" - "errors" - "hash" - "io" - "time" -) - -// maxClientPSKIdentities is the number of client PSK identities the server will -// attempt to validate. It will ignore the rest not to let cheap ClientHello -// messages cause too much work in session ticket decryption attempts. -const maxClientPSKIdentities = 5 - -type serverHandshakeStateTLS13 struct { - c *Conn - ctx context.Context - clientHello *clientHelloMsg - hello *serverHelloMsg - sentDummyCCS bool - usingPSK bool - suite *cipherSuiteTLS13 - cert *Certificate - sigAlg SignatureScheme - earlySecret []byte - sharedKey []byte - handshakeSecret []byte - masterSecret []byte - trafficSecret []byte // client_application_traffic_secret_0 - transcript hash.Hash - clientFinished []byte -} - -func (hs *serverHandshakeStateTLS13) handshake() error { - c := hs.c - - if needFIPS() { - return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode") - } - - switch c.handshakeStatus { - case 0: - // For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2. - if err := hs.processClientHello(); err != nil { - - return err - } - if err := hs.checkForResumption(); err != nil { - return err - } - if err := hs.pickCertificate(); err != nil { - return err - } - c.buffering = true - if err := hs.sendServerParameters(); err != nil { - return err - } - if err := hs.sendServerCertificate(); err != nil { - return err - } - if err := hs.sendServerFinished(); err != nil { - return err - } - // Note that at this point we could start sending application data without - // waiting for the client's second flight, but the application might not - // expect the lack of replay protection of the ClientHello parameters. - if _, err := c.flush(); err != nil { - return err - } - c.handshakeStatus = 1 - case 1: - if err := hs.readClientCertificate(); err != nil { - return err - } - if err := hs.readClientFinished(); err != nil { - return err - } - c.handshakeStatus = 255 - // Enable kernel TLS if possible - if err := c.enableKernelTLS(c.cipherSuite, c.in.key, c.out.key, c.in.iv, c.out.iv); err != nil { - return err - } - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) processClientHello() error { - c := hs.c - - hs.hello = new(serverHelloMsg) - - // TLS 1.3 froze the ServerHello.legacy_version field, and uses - // supported_versions instead. See RFC 8446, sections 4.1.3 and 4.2.1. - hs.hello.vers = VersionTLS12 - hs.hello.supportedVersion = c.vers - - if len(hs.clientHello.supportedVersions) == 0 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client used the legacy version field to negotiate TLS 1.3") - } - - // Abort if the client is doing a fallback and landing lower than what we - // support. See RFC 7507, which however does not specify the interaction - // with supported_versions. The only difference is that with - // supported_versions a client has a chance to attempt a [TLS 1.2, TLS 1.4] - // handshake in case TLS 1.3 is broken but 1.2 is not. Alas, in that case, - // it will have to drop the TLS_FALLBACK_SCSV protection if it falls back to - // TLS 1.2, because a TLS 1.3 server would abort here. The situation before - // supported_versions was not better because there was just no way to do a - // TLS 1.4 handshake without risking the server selecting TLS 1.3. - for _, id := range hs.clientHello.cipherSuites { - if id == TLS_FALLBACK_SCSV { - // Use c.vers instead of max(supported_versions) because an attacker - // could defeat this by adding an arbitrary high version otherwise. - if c.vers < c.config.maxSupportedVersion(roleServer) { - c.sendAlert(alertInappropriateFallback) - return errors.New("tls: client using inappropriate protocol fallback") - } - break - } - } - - if len(hs.clientHello.compressionMethods) != 1 || - hs.clientHello.compressionMethods[0] != compressionNone { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: TLS 1.3 client supports illegal compression methods") - } - - hs.hello.random = make([]byte, 32) - if _, err := io.ReadFull(c.config.rand(), hs.hello.random); err != nil { - c.sendAlert(alertInternalError) - return err - } - - if len(hs.clientHello.secureRenegotiation) != 0 { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: initial handshake had non-empty renegotiation extension") - } - - if hs.clientHello.earlyData { - // See RFC 8446, Section 4.2.10 for the complicated behavior required - // here. The scenario is that a different server at our address offered - // to accept early data in the past, which we can't handle. For now, all - // 0-RTT enabled session tickets need to expire before a Go server can - // replace a server or join a pool. That's the same requirement that - // applies to mixing or replacing with any TLS 1.2 server. - c.sendAlert(alertUnsupportedExtension) - return errors.New("tls: client sent unexpected early data") - } - - hs.hello.sessionId = hs.clientHello.sessionId - hs.hello.compressionMethod = compressionNone - - preferenceList := defaultCipherSuitesTLS13 - if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) { - preferenceList = defaultCipherSuitesTLS13NoAES - } - for _, suiteID := range preferenceList { - hs.suite = mutualCipherSuiteTLS13(hs.clientHello.cipherSuites, suiteID) - if hs.suite != nil { - break - } - } - if hs.suite == nil { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: no cipher suite supported by both client and server") - } - c.cipherSuite = hs.suite.id - hs.hello.cipherSuite = hs.suite.id - hs.transcript = hs.suite.hash.New() - - // Pick the ECDHE group in server preference order, but give priority to - // groups with a key share, to avoid a HelloRetryRequest round-trip. - var selectedGroup CurveID - var clientKeyShare *keyShare -GroupSelection: - for _, preferredGroup := range c.config.curvePreferences() { - for _, ks := range hs.clientHello.keyShares { - if ks.group == preferredGroup { - selectedGroup = ks.group - clientKeyShare = &ks - break GroupSelection - } - } - if selectedGroup != 0 { - continue - } - for _, group := range hs.clientHello.supportedCurves { - if group == preferredGroup { - selectedGroup = group - break - } - } - } - if selectedGroup == 0 { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: no ECDHE curve supported by both client and server") - } - if clientKeyShare == nil { - if err := hs.doHelloRetryRequest(selectedGroup); err != nil { - return err - } - clientKeyShare = &hs.clientHello.keyShares[0] - } - - if _, ok := curveForCurveID(selectedGroup); !ok { - c.sendAlert(alertInternalError) - return errors.New("tls: CurvePreferences includes unsupported curve") - } - key, err := generateECDHEKey(c.config.rand(), selectedGroup) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - hs.hello.serverShare = keyShare{group: selectedGroup, data: key.PublicKey().Bytes()} - peerKey, err := key.Curve().NewPublicKey(clientKeyShare.data) - if err != nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid client key share") - } - hs.sharedKey, err = key.ECDH(peerKey) - if err != nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid client key share") - } - - c.serverName = hs.clientHello.serverName - return nil -} - -func (hs *serverHandshakeStateTLS13) checkForResumption() error { - c := hs.c - - if c.config.SessionTicketsDisabled { - return nil - } - - modeOK := false - for _, mode := range hs.clientHello.pskModes { - if mode == pskModeDHE { - modeOK = true - break - } - } - if !modeOK { - return nil - } - - if len(hs.clientHello.pskIdentities) != len(hs.clientHello.pskBinders) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid or missing PSK binders") - } - if len(hs.clientHello.pskIdentities) == 0 { - return nil - } - - for i, identity := range hs.clientHello.pskIdentities { - if i >= maxClientPSKIdentities { - break - } - - plaintext, _ := c.decryptTicket(identity.label) - if plaintext == nil { - continue - } - sessionState := new(sessionStateTLS13) - if ok := sessionState.unmarshal(plaintext); !ok { - continue - } - - createdAt := time.Unix(int64(sessionState.createdAt), 0) - if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { - continue - } - - // We don't check the obfuscated ticket age because it's affected by - // clock skew and it's only a freshness signal useful for shrinking the - // window for replay attacks, which don't affect us as we don't do 0-RTT. - - pskSuite := cipherSuiteTLS13ByID(sessionState.cipherSuite) - if pskSuite == nil || pskSuite.hash != hs.suite.hash { - continue - } - - // PSK connections don't re-establish client certificates, but carry - // them over in the session ticket. Ensure the presence of client certs - // in the ticket is consistent with the configured requirements. - sessionHasClientCerts := len(sessionState.certificate.Certificate) != 0 - needClientCerts := requiresClientCert(c.config.ClientAuth) - if needClientCerts && !sessionHasClientCerts { - continue - } - if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { - continue - } - - psk := hs.suite.expandLabel(sessionState.resumptionSecret, "resumption", - nil, hs.suite.hash.Size()) - hs.earlySecret = hs.suite.extract(psk, nil) - binderKey := hs.suite.deriveSecret(hs.earlySecret, resumptionBinderLabel, nil) - // Clone the transcript in case a HelloRetryRequest was recorded. - transcript := cloneHash(hs.transcript, hs.suite.hash) - if transcript == nil { - c.sendAlert(alertInternalError) - return errors.New("tls: internal error: failed to clone hash") - } - clientHelloBytes, err := hs.clientHello.marshalWithoutBinders() - if err != nil { - c.sendAlert(alertInternalError) - return err - } - transcript.Write(clientHelloBytes) - pskBinder := hs.suite.finishedHash(binderKey, transcript) - if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid PSK binder") - } - - c.didResume = true - if err := c.processCertsFromClient(sessionState.certificate); err != nil { - return err - } - - hs.hello.selectedIdentityPresent = true - hs.hello.selectedIdentity = uint16(i) - hs.usingPSK = true - return nil - } - - return nil -} - -// cloneHash uses the encoding.BinaryMarshaler and encoding.BinaryUnmarshaler -// interfaces implemented by standard library hashes to clone the state of in -// to a new instance of h. It returns nil if the operation fails. -func cloneHash(in hash.Hash, h crypto.Hash) hash.Hash { - // Recreate the interface to avoid importing encoding. - type binaryMarshaler interface { - MarshalBinary() (data []byte, err error) - UnmarshalBinary(data []byte) error - } - marshaler, ok := in.(binaryMarshaler) - if !ok { - return nil - } - state, err := marshaler.MarshalBinary() - if err != nil { - return nil - } - out := h.New() - unmarshaler, ok := out.(binaryMarshaler) - if !ok { - return nil - } - if err := unmarshaler.UnmarshalBinary(state); err != nil { - return nil - } - return out -} - -func (hs *serverHandshakeStateTLS13) pickCertificate() error { - c := hs.c - - // Only one of PSK and certificates are used at a time. - if hs.usingPSK { - return nil - } - - // signature_algorithms is required in TLS 1.3. See RFC 8446, Section 4.2.3. - if len(hs.clientHello.supportedSignatureAlgorithms) == 0 { - return c.sendAlert(alertMissingExtension) - } - - certificate, err := c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello)) - if err != nil { - if err == errNoCertificates { - c.sendAlert(alertUnrecognizedName) - } else { - c.sendAlert(alertInternalError) - } - return err - } - hs.sigAlg, err = selectSignatureScheme(c.vers, certificate, hs.clientHello.supportedSignatureAlgorithms) - if err != nil { - // getCertificate returned a certificate that is unsupported or - // incompatible with the client's signature algorithms. - c.sendAlert(alertHandshakeFailure) - return err - } - hs.cert = certificate - - return nil -} - -// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility -// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. -func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error { - if hs.sentDummyCCS { - return nil - } - hs.sentDummyCCS = true - - return hs.c.writeChangeCipherRecord() -} - -func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error { - c := hs.c - - // The first ClientHello gets double-hashed into the transcript upon a - // HelloRetryRequest. See RFC 8446, Section 4.4.1. - if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil { - return err - } - chHash := hs.transcript.Sum(nil) - hs.transcript.Reset() - hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) - hs.transcript.Write(chHash) - - helloRetryRequest := &serverHelloMsg{ - vers: hs.hello.vers, - random: helloRetryRequestRandom, - sessionId: hs.hello.sessionId, - cipherSuite: hs.hello.cipherSuite, - compressionMethod: hs.hello.compressionMethod, - supportedVersion: hs.hello.supportedVersion, - selectedGroup: selectedGroup, - } - - if _, err := hs.c.writeHandshakeRecord(helloRetryRequest, hs.transcript); err != nil { - return err - } - - if err := hs.sendDummyChangeCipherSpec(); err != nil { - return err - } - - // clientHelloMsg is not included in the transcript. - msg, err := c.readHandshake(nil) - if err != nil { - return err - } - - clientHello, ok := msg.(*clientHelloMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(clientHello, msg) - } - - if len(clientHello.keyShares) != 1 || clientHello.keyShares[0].group != selectedGroup { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client sent invalid key share in second ClientHello") - } - - if clientHello.earlyData { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client indicated early data in second ClientHello") - } - - if illegalClientHelloChange(clientHello, hs.clientHello) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client illegally modified second ClientHello") - } - - hs.clientHello = clientHello - return nil -} - -// illegalClientHelloChange reports whether the two ClientHello messages are -// different, with the exception of the changes allowed before and after a -// HelloRetryRequest. See RFC 8446, Section 4.1.2. -func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool { - if len(ch.supportedVersions) != len(ch1.supportedVersions) || - len(ch.cipherSuites) != len(ch1.cipherSuites) || - len(ch.supportedCurves) != len(ch1.supportedCurves) || - len(ch.supportedSignatureAlgorithms) != len(ch1.supportedSignatureAlgorithms) || - len(ch.supportedSignatureAlgorithmsCert) != len(ch1.supportedSignatureAlgorithmsCert) || - len(ch.alpnProtocols) != len(ch1.alpnProtocols) { - return true - } - for i := range ch.supportedVersions { - if ch.supportedVersions[i] != ch1.supportedVersions[i] { - return true - } - } - for i := range ch.cipherSuites { - if ch.cipherSuites[i] != ch1.cipherSuites[i] { - return true - } - } - for i := range ch.supportedCurves { - if ch.supportedCurves[i] != ch1.supportedCurves[i] { - return true - } - } - for i := range ch.supportedSignatureAlgorithms { - if ch.supportedSignatureAlgorithms[i] != ch1.supportedSignatureAlgorithms[i] { - return true - } - } - for i := range ch.supportedSignatureAlgorithmsCert { - if ch.supportedSignatureAlgorithmsCert[i] != ch1.supportedSignatureAlgorithmsCert[i] { - return true - } - } - for i := range ch.alpnProtocols { - if ch.alpnProtocols[i] != ch1.alpnProtocols[i] { - return true - } - } - return ch.vers != ch1.vers || - !bytes.Equal(ch.random, ch1.random) || - !bytes.Equal(ch.sessionId, ch1.sessionId) || - !bytes.Equal(ch.compressionMethods, ch1.compressionMethods) || - ch.serverName != ch1.serverName || - ch.ocspStapling != ch1.ocspStapling || - !bytes.Equal(ch.supportedPoints, ch1.supportedPoints) || - ch.ticketSupported != ch1.ticketSupported || - !bytes.Equal(ch.sessionTicket, ch1.sessionTicket) || - ch.secureRenegotiationSupported != ch1.secureRenegotiationSupported || - !bytes.Equal(ch.secureRenegotiation, ch1.secureRenegotiation) || - ch.scts != ch1.scts || - !bytes.Equal(ch.cookie, ch1.cookie) || - !bytes.Equal(ch.pskModes, ch1.pskModes) -} - -func (hs *serverHandshakeStateTLS13) sendServerParameters() error { - c := hs.c - - if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil { - return err - } - if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil { - return err - } - - if err := hs.sendDummyChangeCipherSpec(); err != nil { - return err - } - - earlySecret := hs.earlySecret - if earlySecret == nil { - earlySecret = hs.suite.extract(nil, nil) - } - hs.handshakeSecret = hs.suite.extract(hs.sharedKey, - hs.suite.deriveSecret(earlySecret, "derived", nil)) - - clientSecret := hs.suite.deriveSecret(hs.handshakeSecret, - clientHandshakeTrafficLabel, hs.transcript) - c.in.setTrafficSecret(hs.suite, clientSecret) - serverSecret := hs.suite.deriveSecret(hs.handshakeSecret, - serverHandshakeTrafficLabel, hs.transcript) - c.out.setTrafficSecret(hs.suite, serverSecret) - - err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.clientHello.random, serverSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - encryptedExtensions := new(encryptedExtensionsMsg) - - selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols) - if err != nil { - c.sendAlert(alertNoApplicationProtocol) - return err - } - encryptedExtensions.alpnProtocol = selectedProto - c.clientProtocol = selectedProto - - if _, err := hs.c.writeHandshakeRecord(encryptedExtensions, hs.transcript); err != nil { - return err - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) requestClientCert() bool { - return hs.c.config.ClientAuth >= RequestClientCert && !hs.usingPSK -} - -func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { - c := hs.c - - // Only one of PSK and certificates are used at a time. - if hs.usingPSK { - return nil - } - - if hs.requestClientCert() { - // Request a client certificate - certReq := new(certificateRequestMsgTLS13) - certReq.ocspStapling = true - certReq.scts = true - certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms() - if c.config.ClientCAs != nil { - certReq.certificateAuthorities = c.config.ClientCAs.Subjects() - } - - if _, err := hs.c.writeHandshakeRecord(certReq, hs.transcript); err != nil { - return err - } - } - - certMsg := new(certificateMsgTLS13) - - certMsg.certificate = *hs.cert - certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0 - certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 - - if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil { - return err - } - - certVerifyMsg := new(certificateVerifyMsg) - certVerifyMsg.hasSignatureAlgorithm = true - certVerifyMsg.signatureAlgorithm = hs.sigAlg - - sigType, sigHash, err := typeAndHashFromSignatureScheme(hs.sigAlg) - if err != nil { - return c.sendAlert(alertInternalError) - } - - signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) - signOpts := crypto.SignerOpts(sigHash) - if sigType == signatureRSAPSS { - signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} - } - sig, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) - if err != nil { - public := hs.cert.PrivateKey.(crypto.Signer).Public() - if rsaKey, ok := public.(*rsa.PublicKey); ok && sigType == signatureRSAPSS && - rsaKey.N.BitLen()/8 < sigHash.Size()*2+2 { // key too small for RSA-PSS - c.sendAlert(alertHandshakeFailure) - } else { - c.sendAlert(alertInternalError) - } - return errors.New("tls: failed to sign handshake: " + err.Error()) - } - certVerifyMsg.signature = sig - - if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil { - return err - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) sendServerFinished() error { - c := hs.c - - finished := &finishedMsg{ - verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), - } - - if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil { - return err - } - - // Derive secrets that take context through the server Finished. - - hs.masterSecret = hs.suite.extract(nil, - hs.suite.deriveSecret(hs.handshakeSecret, "derived", nil)) - - hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, - clientApplicationTrafficLabel, hs.transcript) - serverSecret := hs.suite.deriveSecret(hs.masterSecret, - serverApplicationTrafficLabel, hs.transcript) - c.out.setTrafficSecret(hs.suite, serverSecret) - - err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.clientHello.random, serverSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript) - - // If we did not request client certificates, at this point we can - // precompute the client finished and roll the transcript forward to send - // session tickets in our first flight. - if !hs.requestClientCert() { - if err := hs.sendSessionTickets(); err != nil { - return err - } - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) shouldSendSessionTickets() bool { - if hs.c.config.SessionTicketsDisabled { - return false - } - - // Don't send tickets the client wouldn't use. See RFC 8446, Section 4.2.9. - for _, pskMode := range hs.clientHello.pskModes { - if pskMode == pskModeDHE { - return true - } - } - return false -} - -func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { - c := hs.c - - hs.clientFinished = hs.suite.finishedHash(c.in.trafficSecret, hs.transcript) - finishedMsg := &finishedMsg{ - verifyData: hs.clientFinished, - } - if err := transcriptMsg(finishedMsg, hs.transcript); err != nil { - return err - } - - if !hs.shouldSendSessionTickets() { - return nil - } - - resumptionSecret := hs.suite.deriveSecret(hs.masterSecret, - resumptionLabel, hs.transcript) - - m := new(newSessionTicketMsgTLS13) - - var certsFromClient [][]byte - for _, cert := range c.peerCertificates { - certsFromClient = append(certsFromClient, cert.Raw) - } - state := sessionStateTLS13{ - cipherSuite: hs.suite.id, - createdAt: uint64(c.config.time().Unix()), - resumptionSecret: resumptionSecret, - certificate: Certificate{ - Certificate: certsFromClient, - OCSPStaple: c.ocspResponse, - SignedCertificateTimestamps: c.scts, - }, - } - stateBytes, err := state.marshal() - if err != nil { - c.sendAlert(alertInternalError) - return err - } - m.label, err = c.encryptTicket(stateBytes) - if err != nil { - return err - } - m.lifetime = uint32(maxSessionTicketLifetime / time.Second) - - // ticket_age_add is a random 32-bit value. See RFC 8446, section 4.6.1 - // The value is not stored anywhere; we never need to check the ticket age - // because 0-RTT is not supported. - ageAdd := make([]byte, 4) - _, err = hs.c.config.rand().Read(ageAdd) - if err != nil { - return err - } - m.ageAdd = binary.LittleEndian.Uint32(ageAdd) - - // ticket_nonce, which must be unique per connection, is always left at - // zero because we only ever send one ticket per connection. - - if _, err := c.writeHandshakeRecord(m, nil); err != nil { - return err - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) readClientCertificate() error { - c := hs.c - - if !hs.requestClientCert() { - // Make sure the connection is still being verified whether or not - // the server requested a client certificate. - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - return nil - } - - // If we requested a client certificate, then the client must send a - // certificate message. If it's empty, no CertificateVerify is sent. - - msg, err := c.readHandshake(hs.transcript) - if err != nil { - return err - } - - certMsg, ok := msg.(*certificateMsgTLS13) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certMsg, msg) - } - - if err := c.processCertsFromClient(certMsg.certificate); err != nil { - return err - } - - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - if len(certMsg.certificate.Certificate) != 0 { - // certificateVerifyMsg is included in the transcript, but not until - // after we verify the handshake signature, since the state before - // this message was sent is used. - msg, err = c.readHandshake(nil) - if err != nil { - return err - } - - certVerify, ok := msg.(*certificateVerifyMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certVerify, msg) - } - - // See RFC 8446, Section 4.4.3. - if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client certificate used with invalid signature algorithm") - } - sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) - if err != nil { - return c.sendAlert(alertInternalError) - } - if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client certificate used with invalid signature algorithm") - } - signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) - if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, - sigHash, signed, certVerify.signature); err != nil { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid signature by the client certificate: " + err.Error()) - } - - if err := transcriptMsg(certVerify, hs.transcript); err != nil { - return err - } - } - - // If we waited until the client certificates to send session tickets, we - // are ready to do it now. - if err := hs.sendSessionTickets(); err != nil { - return err - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) readClientFinished() error { - c := hs.c - - // finishedMsg is not included in the transcript. - msg, err := c.readHandshake(nil) - if err != nil { - return err - } - - finished, ok := msg.(*finishedMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(finished, msg) - } - - if !hmac.Equal(hs.clientFinished, finished.verifyData) { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid client finished hash") - } - - c.in.setTrafficSecret(hs.suite, hs.trafficSecret) - - return nil -} diff --git a/pkg/tls/handshake_test.go b/pkg/tls/handshake_test.go deleted file mode 100644 index bacc8b7d4..000000000 --- a/pkg/tls/handshake_test.go +++ /dev/null @@ -1,530 +0,0 @@ -// Copyright 2013 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "bufio" - "crypto/ed25519" - "crypto/x509" - "encoding/hex" - "errors" - "flag" - "fmt" - "io" - "net" - "os" - "os/exec" - "runtime" - "strconv" - "strings" - "sync" - "testing" - "time" -) - -// TLS reference tests run a connection against a reference implementation -// (OpenSSL) of TLS and record the bytes of the resulting connection. The Go -// code, during a test, is configured with deterministic randomness and so the -// reference test can be reproduced exactly in the future. -// -// In order to save everyone who wishes to run the tests from needing the -// reference implementation installed, the reference connections are saved in -// files in the testdata directory. Thus running the tests involves nothing -// external, but creating and updating them requires the reference -// implementation. -// -// Tests can be updated by running them with the -update flag. This will cause -// the test files for failing tests to be regenerated. Since the reference -// implementation will always generate fresh random numbers, large parts of the -// reference connection will always change. - -var ( - update = flag.Bool("update", false, "update golden files on failure") - fast = flag.Bool("fast", false, "impose a quick, possibly flaky timeout on recorded tests") - keyFile = flag.String("keylog", "", "destination file for KeyLogWriter") -) - -func runTestAndUpdateIfNeeded(t *testing.T, name string, run func(t *testing.T, update bool), wait bool) { - success := t.Run(name, func(t *testing.T) { - if !*update && !wait { - t.Parallel() - } - run(t, false) - }) - - if !success && *update { - t.Run(name+"#update", func(t *testing.T) { - run(t, true) - }) - } -} - -// checkOpenSSLVersion ensures that the version of OpenSSL looks reasonable -// before updating the test data. -func checkOpenSSLVersion() error { - if !*update { - return nil - } - - openssl := exec.Command("openssl", "version") - output, err := openssl.CombinedOutput() - if err != nil { - return err - } - - version := string(output) - if strings.HasPrefix(version, "OpenSSL 1.1.1") { - return nil - } - - println("***********************************************") - println("") - println("You need to build OpenSSL 1.1.1 from source in order") - println("to update the test data.") - println("") - println("Configure it with:") - println("./Configure enable-weak-ssl-ciphers no-shared") - println("and then add the apps/ directory at the front of your PATH.") - println("***********************************************") - - return errors.New("version of OpenSSL does not appear to be suitable for updating test data") -} - -// recordingConn is a net.Conn that records the traffic that passes through it. -// WriteTo can be used to produce output that can be later be loaded with -// ParseTestData. -type recordingConn struct { - net.Conn - sync.Mutex - flows [][]byte - reading bool -} - -func (r *recordingConn) Read(b []byte) (n int, err error) { - if n, err = r.Conn.Read(b); n == 0 { - return - } - b = b[:n] - - r.Lock() - defer r.Unlock() - - if l := len(r.flows); l == 0 || !r.reading { - buf := make([]byte, len(b)) - copy(buf, b) - r.flows = append(r.flows, buf) - } else { - r.flows[l-1] = append(r.flows[l-1], b[:n]...) - } - r.reading = true - return -} - -func (r *recordingConn) Write(b []byte) (n int, err error) { - if n, err = r.Conn.Write(b); n == 0 { - return - } - b = b[:n] - - r.Lock() - defer r.Unlock() - - if l := len(r.flows); l == 0 || r.reading { - buf := make([]byte, len(b)) - copy(buf, b) - r.flows = append(r.flows, buf) - } else { - r.flows[l-1] = append(r.flows[l-1], b[:n]...) - } - r.reading = false - return -} - -// WriteTo writes Go source code to w that contains the recorded traffic. -func (r *recordingConn) WriteTo(w io.Writer) (int64, error) { - // TLS always starts with a client to server flow. - clientToServer := true - var written int64 - for i, flow := range r.flows { - source, dest := "client", "server" - if !clientToServer { - source, dest = dest, source - } - n, err := fmt.Fprintf(w, ">>> Flow %d (%s to %s)\n", i+1, source, dest) - written += int64(n) - if err != nil { - return written, err - } - dumper := hex.Dumper(w) - n, err = dumper.Write(flow) - written += int64(n) - if err != nil { - return written, err - } - err = dumper.Close() - if err != nil { - return written, err - } - clientToServer = !clientToServer - } - return written, nil -} - -func parseTestData(r io.Reader) (flows [][]byte, err error) { - var currentFlow []byte - - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := scanner.Text() - // If the line starts with ">>> " then it marks the beginning - // of a new flow. - if strings.HasPrefix(line, ">>> ") { - if len(currentFlow) > 0 || len(flows) > 0 { - flows = append(flows, currentFlow) - currentFlow = nil - } - continue - } - - // Otherwise the line is a line of hex dump that looks like: - // 00000170 fc f5 06 bf (...) |.....X{&?......!| - // (Some bytes have been omitted from the middle section.) - _, after, ok := strings.Cut(line, " ") - if !ok { - return nil, errors.New("invalid test data") - } - line = after - - before, _, ok := strings.Cut(line, "|") - if !ok { - return nil, errors.New("invalid test data") - } - line = before - - hexBytes := strings.Fields(line) - for _, hexByte := range hexBytes { - val, err := strconv.ParseUint(hexByte, 16, 8) - if err != nil { - return nil, errors.New("invalid hex byte in test data: " + err.Error()) - } - currentFlow = append(currentFlow, byte(val)) - } - } - - if len(currentFlow) > 0 { - flows = append(flows, currentFlow) - } - - return flows, nil -} - -// tempFile creates a temp file containing contents and returns its path. -func tempFile(contents string) string { - file, err := os.CreateTemp("", "go-tls-test") - if err != nil { - panic("failed to create temp file: " + err.Error()) - } - path := file.Name() - file.WriteString(contents) - file.Close() - return path -} - -// localListener is set up by TestMain and used by localPipe to create Conn -// pairs like net.Pipe, but connected by an actual buffered TCP connection. -var localListener struct { - mu sync.Mutex - addr net.Addr - ch chan net.Conn -} - -const localFlakes = 0 // change to 1 or 2 to exercise localServer/localPipe handling of mismatches - -func localServer(l net.Listener) { - for n := 0; ; n++ { - c, err := l.Accept() - if err != nil { - return - } - if localFlakes == 1 && n%2 == 0 { - c.Close() - continue - } - localListener.ch <- c - } -} - -var isConnRefused = func(err error) bool { return false } - -func localPipe(t testing.TB) (net.Conn, net.Conn) { - localListener.mu.Lock() - defer localListener.mu.Unlock() - - addr := localListener.addr - - var err error -Dialing: - // We expect a rare mismatch, but probably not 5 in a row. - for i := 0; i < 5; i++ { - tooSlow := time.NewTimer(1 * time.Second) - defer tooSlow.Stop() - var c1 net.Conn - c1, err = net.Dial(addr.Network(), addr.String()) - if err != nil { - if runtime.GOOS == "dragonfly" && (isConnRefused(err) || os.IsTimeout(err)) { - // golang.org/issue/29583: Dragonfly sometimes returns a spurious - // ECONNREFUSED or ETIMEDOUT. - <-tooSlow.C - continue - } - t.Fatalf("localPipe: %v", err) - } - if localFlakes == 2 && i == 0 { - c1.Close() - continue - } - for { - select { - case <-tooSlow.C: - t.Logf("localPipe: timeout waiting for %v", c1.LocalAddr()) - c1.Close() - continue Dialing - - case c2 := <-localListener.ch: - if c2.RemoteAddr().String() == c1.LocalAddr().String() { - return c1, c2 - } - t.Logf("localPipe: unexpected connection: %v != %v", c2.RemoteAddr(), c1.LocalAddr()) - c2.Close() - } - } - } - - t.Fatalf("localPipe: failed to connect: %v", err) - panic("unreachable") -} - -// zeroSource is an io.Reader that returns an unlimited number of zero bytes. -type zeroSource struct{} - -func (zeroSource) Read(b []byte) (n int, err error) { - for i := range b { - b[i] = 0 - } - - return len(b), nil -} - -func allCipherSuites() []uint16 { - ids := make([]uint16, len(cipherSuites)) - for i, suite := range cipherSuites { - ids[i] = suite.id - } - - return ids -} - -var testConfig *Config - -func TestMain(m *testing.M) { - flag.Parse() - os.Exit(runMain(m)) -} - -func runMain(m *testing.M) int { - // Cipher suites preferences change based on the architecture. Force them to - // the version without AES acceleration for test consistency. - hasAESGCMHardwareSupport = false - - // Set up localPipe. - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - l, err = net.Listen("tcp6", "[::1]:0") - } - if err != nil { - fmt.Fprintf(os.Stderr, "Failed to open local listener: %v", err) - os.Exit(1) - } - localListener.ch = make(chan net.Conn) - localListener.addr = l.Addr() - defer l.Close() - go localServer(l) - - if err := checkOpenSSLVersion(); err != nil { - fmt.Fprintf(os.Stderr, "Error: %v", err) - os.Exit(1) - } - - testConfig = &Config{ - Time: func() time.Time { return time.Unix(0, 0) }, - Rand: zeroSource{}, - Certificates: make([]Certificate, 2), - InsecureSkipVerify: true, - CipherSuites: allCipherSuites(), - MinVersion: VersionTLS10, - MaxVersion: VersionTLS13, - } - testConfig.Certificates[0].Certificate = [][]byte{testRSACertificate} - testConfig.Certificates[0].PrivateKey = testRSAPrivateKey - testConfig.Certificates[1].Certificate = [][]byte{testSNICertificate} - testConfig.Certificates[1].PrivateKey = testRSAPrivateKey - testConfig.BuildNameToCertificate() - if *keyFile != "" { - f, err := os.OpenFile(*keyFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - panic("failed to open -keylog file: " + err.Error()) - } - testConfig.KeyLogWriter = f - defer f.Close() - } - - return m.Run() -} - -func testHandshake(t *testing.T, clientConfig, serverConfig *Config) (serverState, clientState ConnectionState, err error) { - const sentinel = "SENTINEL\n" - c, s := localPipe(t) - errChan := make(chan error) - go func() { - cli := Client(c, clientConfig) - err := cli.Handshake() - if err != nil { - errChan <- fmt.Errorf("client: %v", err) - c.Close() - return - } - defer cli.Close() - clientState = cli.ConnectionState() - buf, err := io.ReadAll(cli) - if err != nil { - t.Errorf("failed to call cli.Read: %v", err) - } - if got := string(buf); got != sentinel { - t.Errorf("read %q from TLS connection, but expected %q", got, sentinel) - } - errChan <- nil - }() - server := Server(s, serverConfig) - err = server.Handshake() - if err == nil { - serverState = server.ConnectionState() - if _, err := io.WriteString(server, sentinel); err != nil { - t.Errorf("failed to call server.Write: %v", err) - } - if err := server.Close(); err != nil { - t.Errorf("failed to call server.Close: %v", err) - } - err = <-errChan - } else { - s.Close() - <-errChan - } - return -} - -func fromHex(s string) []byte { - b, _ := hex.DecodeString(s) - return b -} - -var testRSACertificate = fromHex("3082024b308201b4a003020102020900e8f09d3fe25beaa6300d06092a864886f70d01010b0500301f310b3009060355040a1302476f3110300e06035504031307476f20526f6f74301e170d3136303130313030303030305a170d3235303130313030303030305a301a310b3009060355040a1302476f310b300906035504031302476f30819f300d06092a864886f70d010101050003818d0030818902818100db467d932e12270648bc062821ab7ec4b6a25dfe1e5245887a3647a5080d92425bc281c0be97799840fb4f6d14fd2b138bc2a52e67d8d4099ed62238b74a0b74732bc234f1d193e596d9747bf3589f6c613cc0b041d4d92b2b2423775b1c3bbd755dce2054cfa163871d1e24c4f31d1a508baab61443ed97a77562f414c852d70203010001a38193308190300e0603551d0f0101ff0404030205a0301d0603551d250416301406082b0601050507030106082b06010505070302300c0603551d130101ff0402300030190603551d0e041204109f91161f43433e49a6de6db680d79f60301b0603551d230414301280104813494d137e1631bba301d5acab6e7b30190603551d1104123010820e6578616d706c652e676f6c616e67300d06092a864886f70d01010b0500038181009d30cc402b5b50a061cbbae55358e1ed8328a9581aa938a495a1ac315a1a84663d43d32dd90bf297dfd320643892243a00bccf9c7db74020015faad3166109a276fd13c3cce10c5ceeb18782f16c04ed73bbb343778d0c1cf10fa1d8408361c94c722b9daedb4606064df4c1b33ec0d1bd42d4dbfe3d1360845c21d33be9fae7") - -var testRSACertificateIssuer = fromHex("3082021930820182a003020102020900ca5e4e811a965964300d06092a864886f70d01010b0500301f310b3009060355040a1302476f3110300e06035504031307476f20526f6f74301e170d3136303130313030303030305a170d3235303130313030303030305a301f310b3009060355040a1302476f3110300e06035504031307476f20526f6f7430819f300d06092a864886f70d010101050003818d0030818902818100d667b378bb22f34143b6cd2008236abefaf2852adf3ab05e01329e2c14834f5105df3f3073f99dab5442d45ee5f8f57b0111c8cb682fbb719a86944eebfffef3406206d898b8c1b1887797c9c5006547bb8f00e694b7a063f10839f269f2c34fff7a1f4b21fbcd6bfdfb13ac792d1d11f277b5c5b48600992203059f2a8f8cc50203010001a35d305b300e0603551d0f0101ff040403020204301d0603551d250416301406082b0601050507030106082b06010505070302300f0603551d130101ff040530030101ff30190603551d0e041204104813494d137e1631bba301d5acab6e7b300d06092a864886f70d01010b050003818100c1154b4bab5266221f293766ae4138899bd4c5e36b13cee670ceeaa4cbdf4f6679017e2fe649765af545749fe4249418a56bd38a04b81e261f5ce86b8d5c65413156a50d12449554748c59a30c515bc36a59d38bddf51173e899820b282e40aa78c806526fd184fb6b4cf186ec728edffa585440d2b3225325f7ab580e87dd76") - -// testRSAPSSCertificate has signatureAlgorithm rsassaPss, but subjectPublicKeyInfo -// algorithm rsaEncryption, for use with the rsa_pss_rsae_* SignatureSchemes. -// See also TestRSAPSSKeyError. testRSAPSSCertificate is self-signed. -var testRSAPSSCertificate = fromHex("308202583082018da003020102021100f29926eb87ea8a0db9fcc247347c11b0304106092a864886f70d01010a3034a00f300d06096086480165030402010500a11c301a06092a864886f70d010108300d06096086480165030402010500a20302012030123110300e060355040a130741636d6520436f301e170d3137313132333136313631305a170d3138313132333136313631305a30123110300e060355040a130741636d6520436f30819f300d06092a864886f70d010101050003818d0030818902818100db467d932e12270648bc062821ab7ec4b6a25dfe1e5245887a3647a5080d92425bc281c0be97799840fb4f6d14fd2b138bc2a52e67d8d4099ed62238b74a0b74732bc234f1d193e596d9747bf3589f6c613cc0b041d4d92b2b2423775b1c3bbd755dce2054cfa163871d1e24c4f31d1a508baab61443ed97a77562f414c852d70203010001a3463044300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff04023000300f0603551d110408300687047f000001304106092a864886f70d01010a3034a00f300d06096086480165030402010500a11c301a06092a864886f70d010108300d06096086480165030402010500a20302012003818100cdac4ef2ce5f8d79881042707f7cbf1b5a8a00ef19154b40151771006cd41626e5496d56da0c1a139fd84695593cb67f87765e18aa03ea067522dd78d2a589b8c92364e12838ce346c6e067b51f1a7e6f4b37ffab13f1411896679d18e880e0ba09e302ac067efca460288e9538122692297ad8093d4f7dd701424d7700a46a1") - -var testECDSACertificate = fromHex("3082020030820162020900b8bf2d47a0d2ebf4300906072a8648ce3d04013045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464301e170d3132313132323135303633325a170d3232313132303135303633325a3045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c746430819b301006072a8648ce3d020106052b81040023038186000400c4a1edbe98f90b4873367ec316561122f23d53c33b4d213dcd6b75e6f6b0dc9adf26c1bcb287f072327cb3642f1c90bcea6823107efee325c0483a69e0286dd33700ef0462dd0da09c706283d881d36431aa9e9731bd96b068c09b23de76643f1a5c7fe9120e5858b65f70dd9bd8ead5d7f5d5ccb9b69f30665b669a20e227e5bffe3b300906072a8648ce3d040103818c0030818802420188a24febe245c5487d1bacf5ed989dae4770c05e1bb62fbdf1b64db76140d311a2ceee0b7e927eff769dc33b7ea53fcefa10e259ec472d7cacda4e970e15a06fd00242014dfcbe67139c2d050ebd3fa38c25c13313830d9406bbd4377af6ec7ac9862eddd711697f857c56defb31782be4c7780daecbbe9e4e3624317b6a0f399512078f2a") - -var testEd25519Certificate = fromHex("3082012e3081e1a00302010202100f431c425793941de987e4f1ad15005d300506032b657030123110300e060355040a130741636d6520436f301e170d3139303531363231333830315a170d3230303531353231333830315a30123110300e060355040a130741636d6520436f302a300506032b65700321003fe2152ee6e3ef3f4e854a7577a3649eede0bf842ccc92268ffa6f3483aaec8fa34d304b300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff0402300030160603551d11040f300d820b6578616d706c652e636f6d300506032b65700341006344ed9cc4be5324539fd2108d9fe82108909539e50dc155ff2c16b71dfcab7d4dd4e09313d0a942e0b66bfe5d6748d79f50bc6ccd4b03837cf20858cdaccf0c") - -var testSNICertificate = fromHex("0441883421114c81480804c430820237308201a0a003020102020900e8f09d3fe25beaa6300d06092a864886f70d01010b0500301f310b3009060355040a1302476f3110300e06035504031307476f20526f6f74301e170d3136303130313030303030305a170d3235303130313030303030305a3023310b3009060355040a1302476f311430120603550403130b736e69746573742e636f6d30819f300d06092a864886f70d010101050003818d0030818902818100db467d932e12270648bc062821ab7ec4b6a25dfe1e5245887a3647a5080d92425bc281c0be97799840fb4f6d14fd2b138bc2a52e67d8d4099ed62238b74a0b74732bc234f1d193e596d9747bf3589f6c613cc0b041d4d92b2b2423775b1c3bbd755dce2054cfa163871d1e24c4f31d1a508baab61443ed97a77562f414c852d70203010001a3773075300e0603551d0f0101ff0404030205a0301d0603551d250416301406082b0601050507030106082b06010505070302300c0603551d130101ff0402300030190603551d0e041204109f91161f43433e49a6de6db680d79f60301b0603551d230414301280104813494d137e1631bba301d5acab6e7b300d06092a864886f70d01010b0500038181007beeecff0230dbb2e7a334af65430b7116e09f327c3bbf918107fc9c66cb497493207ae9b4dbb045cb63d605ec1b5dd485bb69124d68fa298dc776699b47632fd6d73cab57042acb26f083c4087459bc5a3bb3ca4d878d7fe31016b7bc9a627438666566e3389bfaeebe6becc9a0093ceed18d0f9ac79d56f3a73f18188988ed") - -var testP256Certificate = fromHex("308201693082010ea00302010202105012dc24e1124ade4f3e153326ff27bf300a06082a8648ce3d04030230123110300e060355040a130741636d6520436f301e170d3137303533313232343934375a170d3138303533313232343934375a30123110300e060355040a130741636d6520436f3059301306072a8648ce3d020106082a8648ce3d03010703420004c02c61c9b16283bbcc14956d886d79b358aa614596975f78cece787146abf74c2d5dc578c0992b4f3c631373479ebf3892efe53d21c4f4f1cc9a11c3536b7f75a3463044300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff04023000300f0603551d1104083006820474657374300a06082a8648ce3d0403020349003046022100963712d6226c7b2bef41512d47e1434131aaca3ba585d666c924df71ac0448b3022100f4d05c725064741aef125f243cdbccaa2a5d485927831f221c43023bd5ae471a") - -var testRSAPrivateKey, _ = x509.ParsePKCS1PrivateKey(fromHex("3082025b02010002818100db467d932e12270648bc062821ab7ec4b6a25dfe1e5245887a3647a5080d92425bc281c0be97799840fb4f6d14fd2b138bc2a52e67d8d4099ed62238b74a0b74732bc234f1d193e596d9747bf3589f6c613cc0b041d4d92b2b2423775b1c3bbd755dce2054cfa163871d1e24c4f31d1a508baab61443ed97a77562f414c852d702030100010281800b07fbcf48b50f1388db34b016298b8217f2092a7c9a04f77db6775a3d1279b62ee9951f7e371e9de33f015aea80660760b3951dc589a9f925ed7de13e8f520e1ccbc7498ce78e7fab6d59582c2386cc07ed688212a576ff37833bd5943483b5554d15a0b9b4010ed9bf09f207e7e9805f649240ed6c1256ed75ab7cd56d9671024100fded810da442775f5923debae4ac758390a032a16598d62f059bb2e781a9c2f41bfa015c209f966513fe3bf5a58717cbdb385100de914f88d649b7d15309fa49024100dd10978c623463a1802c52f012cfa72ff5d901f25a2292446552c2568b1840e49a312e127217c2186615aae4fb6602a4f6ebf3f3d160f3b3ad04c592f65ae41f02400c69062ca781841a09de41ed7a6d9f54adc5d693a2c6847949d9e1358555c9ac6a8d9e71653ac77beb2d3abaf7bb1183aa14278956575dbebf525d0482fd72d90240560fe1900ba36dae3022115fd952f2399fb28e2975a1c3e3d0b679660bdcb356cc189d611cfdd6d87cd5aea45aa30a2082e8b51e94c2f3dd5d5c6036a8a615ed0240143993d80ece56f877cb80048335701eb0e608cc0c1ca8c2227b52edf8f1ac99c562f2541b5ce81f0515af1c5b4770dba53383964b4b725ff46fdec3d08907df")) - -var testECDSAPrivateKey, _ = x509.ParseECPrivateKey(fromHex("3081dc0201010442019883e909ad0ac9ea3d33f9eae661f1785206970f8ca9a91672f1eedca7a8ef12bd6561bb246dda5df4b4d5e7e3a92649bc5d83a0bf92972e00e62067d0c7bd99d7a00706052b81040023a18189038186000400c4a1edbe98f90b4873367ec316561122f23d53c33b4d213dcd6b75e6f6b0dc9adf26c1bcb287f072327cb3642f1c90bcea6823107efee325c0483a69e0286dd33700ef0462dd0da09c706283d881d36431aa9e9731bd96b068c09b23de76643f1a5c7fe9120e5858b65f70dd9bd8ead5d7f5d5ccb9b69f30665b669a20e227e5bffe3b")) - -var testP256PrivateKey, _ = x509.ParseECPrivateKey(fromHex("30770201010420012f3b52bc54c36ba3577ad45034e2e8efe1e6999851284cb848725cfe029991a00a06082a8648ce3d030107a14403420004c02c61c9b16283bbcc14956d886d79b358aa614596975f78cece787146abf74c2d5dc578c0992b4f3c631373479ebf3892efe53d21c4f4f1cc9a11c3536b7f75")) - -var testEd25519PrivateKey = ed25519.PrivateKey(fromHex("3a884965e76b3f55e5faf9615458a92354894234de3ec9f684d46d55cebf3dc63fe2152ee6e3ef3f4e854a7577a3649eede0bf842ccc92268ffa6f3483aaec8f")) - -const clientCertificatePEM = ` ------BEGIN CERTIFICATE----- -MIIB7zCCAVigAwIBAgIQXBnBiWWDVW/cC8m5k5/pvDANBgkqhkiG9w0BAQsFADAS -MRAwDgYDVQQKEwdBY21lIENvMB4XDTE2MDgxNzIxNTIzMVoXDTE3MDgxNzIxNTIz -MVowEjEQMA4GA1UEChMHQWNtZSBDbzCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkC -gYEAum+qhr3Pv5/y71yUYHhv6BPy0ZZvzdkybiI3zkH5yl0prOEn2mGi7oHLEMff -NFiVhuk9GeZcJ3NgyI14AvQdpJgJoxlwaTwlYmYqqyIjxXuFOE8uCXMyp70+m63K -hAfmDzr/d8WdQYUAirab7rCkPy1MTOZCPrtRyN1IVPQMjkcCAwEAAaNGMEQwDgYD -VR0PAQH/BAQDAgWgMBMGA1UdJQQMMAoGCCsGAQUFBwMBMAwGA1UdEwEB/wQCMAAw -DwYDVR0RBAgwBocEfwAAATANBgkqhkiG9w0BAQsFAAOBgQBGq0Si+yhU+Fpn+GKU -8ZqyGJ7ysd4dfm92lam6512oFmyc9wnTN+RLKzZ8Aa1B0jLYw9KT+RBrjpW5LBeK -o0RIvFkTgxYEiKSBXCUNmAysEbEoVr4dzWFihAm/1oDGRY2CLLTYg5vbySK3KhIR -e/oCO8HJ/+rJnahJ05XX1Q7lNQ== ------END CERTIFICATE-----` - -var clientKeyPEM = testingKey(` ------BEGIN RSA TESTING KEY----- -MIICXQIBAAKBgQC6b6qGvc+/n/LvXJRgeG/oE/LRlm/N2TJuIjfOQfnKXSms4Sfa -YaLugcsQx980WJWG6T0Z5lwnc2DIjXgC9B2kmAmjGXBpPCViZiqrIiPFe4U4Ty4J -czKnvT6brcqEB+YPOv93xZ1BhQCKtpvusKQ/LUxM5kI+u1HI3UhU9AyORwIDAQAB -AoGAEJZ03q4uuMb7b26WSQsOMeDsftdatT747LGgs3pNRkMJvTb/O7/qJjxoG+Mc -qeSj0TAZXp+PXXc3ikCECAc+R8rVMfWdmp903XgO/qYtmZGCorxAHEmR80SrfMXv -PJnznLQWc8U9nphQErR+tTESg7xWEzmFcPKwnZd1xg8ERYkCQQDTGtrFczlB2b/Z -9TjNMqUlMnTLIk/a/rPE2fLLmAYhK5sHnJdvDURaH2mF4nso0EGtENnTsh6LATnY -dkrxXGm9AkEA4hXHG2q3MnhgK1Z5hjv+Fnqd+8bcbII9WW4flFs15EKoMgS1w/PJ -zbsySaSy5IVS8XeShmT9+3lrleed4sy+UwJBAJOOAbxhfXP5r4+5R6ql66jES75w -jUCVJzJA5ORJrn8g64u2eGK28z/LFQbv9wXgCwfc72R468BdawFSLa/m2EECQGbZ -rWiFla26IVXV0xcD98VWJsTBZMlgPnSOqoMdM1kSEd4fUmlAYI/dFzV1XYSkOmVr -FhdZnklmpVDeu27P4c0CQQCuCOup0FlJSBpWY1TTfun/KMBkBatMz0VMA3d7FKIU -csPezl677Yjo8u1r/KzeI6zLg87Z8E6r6ZWNc9wBSZK6 ------END RSA TESTING KEY-----`) - -const clientECDSACertificatePEM = ` ------BEGIN CERTIFICATE----- -MIIB/DCCAV4CCQCaMIRsJjXZFzAJBgcqhkjOPQQBMEUxCzAJBgNVBAYTAkFVMRMw -EQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0 -eSBMdGQwHhcNMTIxMTE0MTMyNTUzWhcNMjIxMTEyMTMyNTUzWjBBMQswCQYDVQQG -EwJBVTEMMAoGA1UECBMDTlNXMRAwDgYDVQQHEwdQeXJtb250MRIwEAYDVQQDEwlK -b2VsIFNpbmcwgZswEAYHKoZIzj0CAQYFK4EEACMDgYYABACVjJF1FMBexFe01MNv -ja5oHt1vzobhfm6ySD6B5U7ixohLZNz1MLvT/2XMW/TdtWo+PtAd3kfDdq0Z9kUs -jLzYHQFMH3CQRnZIi4+DzEpcj0B22uCJ7B0rxE4wdihBsmKo+1vx+U56jb0JuK7q -ixgnTy5w/hOWusPTQBbNZU6sER7m8TAJBgcqhkjOPQQBA4GMADCBiAJCAOAUxGBg -C3JosDJdYUoCdFzCgbkWqD8pyDbHgf9stlvZcPE4O1BIKJTLCRpS8V3ujfK58PDa -2RU6+b0DeoeiIzXsAkIBo9SKeDUcSpoj0gq+KxAxnZxfvuiRs9oa9V2jI/Umi0Vw -jWVim34BmT0Y9hCaOGGbLlfk+syxis7iI6CH8OFnUes= ------END CERTIFICATE-----` - -var clientECDSAKeyPEM = testingKey(` ------BEGIN EC PARAMETERS----- -BgUrgQQAIw== ------END EC PARAMETERS----- ------BEGIN EC TESTING KEY----- -MIHcAgEBBEIBkJN9X4IqZIguiEVKMqeBUP5xtRsEv4HJEtOpOGLELwO53SD78Ew8 -k+wLWoqizS3NpQyMtrU8JFdWfj+C57UNkOugBwYFK4EEACOhgYkDgYYABACVjJF1 -FMBexFe01MNvja5oHt1vzobhfm6ySD6B5U7ixohLZNz1MLvT/2XMW/TdtWo+PtAd -3kfDdq0Z9kUsjLzYHQFMH3CQRnZIi4+DzEpcj0B22uCJ7B0rxE4wdihBsmKo+1vx -+U56jb0JuK7qixgnTy5w/hOWusPTQBbNZU6sER7m8Q== ------END EC TESTING KEY-----`) - -const clientEd25519CertificatePEM = ` ------BEGIN CERTIFICATE----- -MIIBLjCB4aADAgECAhAX0YGTviqMISAQJRXoNCNPMAUGAytlcDASMRAwDgYDVQQK -EwdBY21lIENvMB4XDTE5MDUxNjIxNTQyNloXDTIwMDUxNTIxNTQyNlowEjEQMA4G -A1UEChMHQWNtZSBDbzAqMAUGAytlcAMhAAvgtWC14nkwPb7jHuBQsQTIbcd4bGkv -xRStmmNveRKRo00wSzAOBgNVHQ8BAf8EBAMCBaAwEwYDVR0lBAwwCgYIKwYBBQUH -AwIwDAYDVR0TAQH/BAIwADAWBgNVHREEDzANggtleGFtcGxlLmNvbTAFBgMrZXAD -QQD8GRcqlKUx+inILn9boF2KTjRAOdazENwZ/qAicbP1j6FYDc308YUkv+Y9FN/f -7Q7hF9gRomDQijcjKsJGqjoI ------END CERTIFICATE-----` - -var clientEd25519KeyPEM = testingKey(` ------BEGIN TESTING KEY----- -MC4CAQAwBQYDK2VwBCIEINifzf07d9qx3d44e0FSbV4mC/xQxT644RRbpgNpin7I ------END TESTING KEY-----`) diff --git a/pkg/tls/handshake_unix_test.go b/pkg/tls/handshake_unix_test.go deleted file mode 100644 index 86a48f299..000000000 --- a/pkg/tls/handshake_unix_test.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright 2019 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build unix - -package tls - -import ( - "errors" - "syscall" -) - -func init() { - isConnRefused = func(err error) bool { - return errors.Is(err, syscall.ECONNREFUSED) - } -} diff --git a/pkg/tls/key_agreement.go b/pkg/tls/key_agreement.go deleted file mode 100644 index 2c8c5b8d7..000000000 --- a/pkg/tls/key_agreement.go +++ /dev/null @@ -1,366 +0,0 @@ -// Copyright 2010 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "crypto" - "crypto/ecdh" - "crypto/md5" - "crypto/rsa" - "crypto/sha1" - "crypto/x509" - "errors" - "fmt" - "io" -) - -// a keyAgreement implements the client and server side of a TLS key agreement -// protocol by generating and processing key exchange messages. -type keyAgreement interface { - // On the server side, the first two methods are called in order. - - // In the case that the key agreement protocol doesn't use a - // ServerKeyExchange message, generateServerKeyExchange can return nil, - // nil. - generateServerKeyExchange(*Config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error) - processClientKeyExchange(*Config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error) - - // On the client side, the next two methods are called in order. - - // This method may not be called if the server doesn't send a - // ServerKeyExchange message. - processServerKeyExchange(*Config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error - generateClientKeyExchange(*Config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) -} - -var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message") -var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message") - -// rsaKeyAgreement implements the standard TLS key agreement where the client -// encrypts the pre-master secret to the server's public key. -type rsaKeyAgreement struct{} - -func (ka rsaKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { - return nil, nil -} - -func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { - if len(ckx.ciphertext) < 2 { - return nil, errClientKeyExchange - } - ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1]) - if ciphertextLen != len(ckx.ciphertext)-2 { - return nil, errClientKeyExchange - } - ciphertext := ckx.ciphertext[2:] - - priv, ok := cert.PrivateKey.(crypto.Decrypter) - if !ok { - return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter") - } - // Perform constant time RSA PKCS #1 v1.5 decryption - preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48}) - if err != nil { - return nil, err - } - // We don't check the version number in the premaster secret. For one, - // by checking it, we would leak information about the validity of the - // encrypted pre-master secret. Secondly, it provides only a small - // benefit against a downgrade attack and some implementations send the - // wrong version anyway. See the discussion at the end of section - // 7.4.7.1 of RFC 4346. - return preMasterSecret, nil -} - -func (ka rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { - return errors.New("tls: unexpected ServerKeyExchange") -} - -func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { - preMasterSecret := make([]byte, 48) - preMasterSecret[0] = byte(clientHello.vers >> 8) - preMasterSecret[1] = byte(clientHello.vers) - _, err := io.ReadFull(config.rand(), preMasterSecret[2:]) - if err != nil { - return nil, nil, err - } - - rsaKey, ok := cert.PublicKey.(*rsa.PublicKey) - if !ok { - return nil, nil, errors.New("tls: server certificate contains incorrect key type for selected ciphersuite") - } - encrypted, err := rsa.EncryptPKCS1v15(config.rand(), rsaKey, preMasterSecret) - if err != nil { - return nil, nil, err - } - ckx := new(clientKeyExchangeMsg) - ckx.ciphertext = make([]byte, len(encrypted)+2) - ckx.ciphertext[0] = byte(len(encrypted) >> 8) - ckx.ciphertext[1] = byte(len(encrypted)) - copy(ckx.ciphertext[2:], encrypted) - return preMasterSecret, ckx, nil -} - -// sha1Hash calculates a SHA1 hash over the given byte slices. -func sha1Hash(slices [][]byte) []byte { - hsha1 := sha1.New() - for _, slice := range slices { - hsha1.Write(slice) - } - return hsha1.Sum(nil) -} - -// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the -// concatenation of an MD5 and SHA1 hash. -func md5SHA1Hash(slices [][]byte) []byte { - md5sha1 := make([]byte, md5.Size+sha1.Size) - hmd5 := md5.New() - for _, slice := range slices { - hmd5.Write(slice) - } - copy(md5sha1, hmd5.Sum(nil)) - copy(md5sha1[md5.Size:], sha1Hash(slices)) - return md5sha1 -} - -// hashForServerKeyExchange hashes the given slices and returns their digest -// using the given hash function (for >= TLS 1.2) or using a default based on -// the sigType (for earlier TLS versions). For Ed25519 signatures, which don't -// do pre-hashing, it returns the concatenation of the slices. -func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte { - if sigType == signatureEd25519 { - var signed []byte - for _, slice := range slices { - signed = append(signed, slice...) - } - return signed - } - if version >= VersionTLS12 { - h := hashFunc.New() - for _, slice := range slices { - h.Write(slice) - } - digest := h.Sum(nil) - return digest - } - if sigType == signatureECDSA { - return sha1Hash(slices) - } - return md5SHA1Hash(slices) -} - -// ecdheKeyAgreement implements a TLS key agreement where the server -// generates an ephemeral EC public/private key pair and signs it. The -// pre-master secret is then calculated using ECDH. The signature may -// be ECDSA, Ed25519 or RSA. -type ecdheKeyAgreement struct { - version uint16 - isRSA bool - key *ecdh.PrivateKey - - // ckx and preMasterSecret are generated in processServerKeyExchange - // and returned in generateClientKeyExchange. - ckx *clientKeyExchangeMsg - preMasterSecret []byte -} - -func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { - var curveID CurveID - for _, c := range clientHello.supportedCurves { - if config.supportsCurve(c) { - curveID = c - break - } - } - - if curveID == 0 { - return nil, errors.New("tls: no supported elliptic curves offered") - } - if _, ok := curveForCurveID(curveID); !ok { - return nil, errors.New("tls: CurvePreferences includes unsupported curve") - } - - key, err := generateECDHEKey(config.rand(), curveID) - if err != nil { - return nil, err - } - ka.key = key - - // See RFC 4492, Section 5.4. - ecdhePublic := key.PublicKey().Bytes() - serverECDHEParams := make([]byte, 1+2+1+len(ecdhePublic)) - serverECDHEParams[0] = 3 // named curve - serverECDHEParams[1] = byte(curveID >> 8) - serverECDHEParams[2] = byte(curveID) - serverECDHEParams[3] = byte(len(ecdhePublic)) - copy(serverECDHEParams[4:], ecdhePublic) - - priv, ok := cert.PrivateKey.(crypto.Signer) - if !ok { - return nil, fmt.Errorf("tls: certificate private key of type %T does not implement crypto.Signer", cert.PrivateKey) - } - - var signatureAlgorithm SignatureScheme - var sigType uint8 - var sigHash crypto.Hash - if ka.version >= VersionTLS12 { - signatureAlgorithm, err = selectSignatureScheme(ka.version, cert, clientHello.supportedSignatureAlgorithms) - if err != nil { - return nil, err - } - sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) - if err != nil { - return nil, err - } - } else { - sigType, sigHash, err = legacyTypeAndHashFromPublicKey(priv.Public()) - if err != nil { - return nil, err - } - } - if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { - return nil, errors.New("tls: certificate cannot be used with the selected cipher suite") - } - - signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, hello.random, serverECDHEParams) - - signOpts := crypto.SignerOpts(sigHash) - if sigType == signatureRSAPSS { - signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} - } - sig, err := priv.Sign(config.rand(), signed, signOpts) - if err != nil { - return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error()) - } - - skx := new(serverKeyExchangeMsg) - sigAndHashLen := 0 - if ka.version >= VersionTLS12 { - sigAndHashLen = 2 - } - skx.key = make([]byte, len(serverECDHEParams)+sigAndHashLen+2+len(sig)) - copy(skx.key, serverECDHEParams) - k := skx.key[len(serverECDHEParams):] - if ka.version >= VersionTLS12 { - k[0] = byte(signatureAlgorithm >> 8) - k[1] = byte(signatureAlgorithm) - k = k[2:] - } - k[0] = byte(len(sig) >> 8) - k[1] = byte(len(sig)) - copy(k[2:], sig) - - return skx, nil -} - -func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { - if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 { - return nil, errClientKeyExchange - } - - peerKey, err := ka.key.Curve().NewPublicKey(ckx.ciphertext[1:]) - if err != nil { - return nil, errClientKeyExchange - } - preMasterSecret, err := ka.key.ECDH(peerKey) - if err != nil { - return nil, errClientKeyExchange - } - - return preMasterSecret, nil -} - -func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { - if len(skx.key) < 4 { - return errServerKeyExchange - } - if skx.key[0] != 3 { // named curve - return errors.New("tls: server selected unsupported curve") - } - curveID := CurveID(skx.key[1])<<8 | CurveID(skx.key[2]) - - publicLen := int(skx.key[3]) - if publicLen+4 > len(skx.key) { - return errServerKeyExchange - } - serverECDHEParams := skx.key[:4+publicLen] - publicKey := serverECDHEParams[4:] - - sig := skx.key[4+publicLen:] - if len(sig) < 2 { - return errServerKeyExchange - } - - if _, ok := curveForCurveID(curveID); !ok { - return errors.New("tls: server selected unsupported curve") - } - - key, err := generateECDHEKey(config.rand(), curveID) - if err != nil { - return err - } - ka.key = key - - peerKey, err := key.Curve().NewPublicKey(publicKey) - if err != nil { - return errServerKeyExchange - } - ka.preMasterSecret, err = key.ECDH(peerKey) - if err != nil { - return errServerKeyExchange - } - - ourPublicKey := key.PublicKey().Bytes() - ka.ckx = new(clientKeyExchangeMsg) - ka.ckx.ciphertext = make([]byte, 1+len(ourPublicKey)) - ka.ckx.ciphertext[0] = byte(len(ourPublicKey)) - copy(ka.ckx.ciphertext[1:], ourPublicKey) - - var sigType uint8 - var sigHash crypto.Hash - if ka.version >= VersionTLS12 { - signatureAlgorithm := SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1]) - sig = sig[2:] - if len(sig) < 2 { - return errServerKeyExchange - } - - if !isSupportedSignatureAlgorithm(signatureAlgorithm, clientHello.supportedSignatureAlgorithms) { - return errors.New("tls: certificate used with invalid signature algorithm") - } - sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) - if err != nil { - return err - } - } else { - sigType, sigHash, err = legacyTypeAndHashFromPublicKey(cert.PublicKey) - if err != nil { - return err - } - } - if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { - return errServerKeyExchange - } - - sigLen := int(sig[0])<<8 | int(sig[1]) - if sigLen+2 != len(sig) { - return errServerKeyExchange - } - sig = sig[2:] - - signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, serverHello.random, serverECDHEParams) - if err := verifyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil { - return errors.New("tls: invalid signature by the server certificate: " + err.Error()) - } - return nil -} - -func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { - if ka.ckx == nil { - return nil, nil, errors.New("tls: missing ServerKeyExchange message") - } - - return ka.preMasterSecret, ka.ckx, nil -} diff --git a/pkg/tls/key_schedule.go b/pkg/tls/key_schedule.go deleted file mode 100644 index ae8f80a7c..000000000 --- a/pkg/tls/key_schedule.go +++ /dev/null @@ -1,158 +0,0 @@ -// Copyright 2018 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "crypto/ecdh" - "crypto/hmac" - "errors" - "fmt" - "hash" - "io" - - "golang.org/x/crypto/cryptobyte" - "golang.org/x/crypto/hkdf" -) - -// This file contains the functions necessary to compute the TLS 1.3 key -// schedule. See RFC 8446, Section 7. - -const ( - resumptionBinderLabel = "res binder" - clientHandshakeTrafficLabel = "c hs traffic" - serverHandshakeTrafficLabel = "s hs traffic" - clientApplicationTrafficLabel = "c ap traffic" - serverApplicationTrafficLabel = "s ap traffic" - exporterLabel = "exp master" - resumptionLabel = "res master" - trafficUpdateLabel = "traffic upd" -) - -// expandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1. -func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []byte, length int) []byte { - var hkdfLabel cryptobyte.Builder - hkdfLabel.AddUint16(uint16(length)) - hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte("tls13 ")) - b.AddBytes([]byte(label)) - }) - hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(context) - }) - hkdfLabelBytes, err := hkdfLabel.Bytes() - if err != nil { - // Rather than calling BytesOrPanic, we explicitly handle this error, in - // order to provide a reasonable error message. It should be basically - // impossible for this to panic, and routing errors back through the - // tree rooted in this function is quite painful. The labels are fixed - // size, and the context is either a fixed-length computed hash, or - // parsed from a field which has the same length limitation. As such, an - // error here is likely to only be caused during development. - // - // NOTE: another reasonable approach here might be to return a - // randomized slice if we encounter an error, which would break the - // connection, but avoid panicking. This would perhaps be safer but - // significantly more confusing to users. - panic(fmt.Errorf("failed to construct HKDF label: %s", err)) - } - out := make([]byte, length) - n, err := hkdf.Expand(c.hash.New, secret, hkdfLabelBytes).Read(out) - if err != nil || n != length { - panic("tls: HKDF-Expand-Label invocation failed unexpectedly") - } - return out -} - -// deriveSecret implements Derive-Secret from RFC 8446, Section 7.1. -func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte { - if transcript == nil { - transcript = c.hash.New() - } - return c.expandLabel(secret, label, transcript.Sum(nil), c.hash.Size()) -} - -// extract implements HKDF-Extract with the cipher suite hash. -func (c *cipherSuiteTLS13) extract(newSecret, currentSecret []byte) []byte { - if newSecret == nil { - newSecret = make([]byte, c.hash.Size()) - } - return hkdf.Extract(c.hash.New, newSecret, currentSecret) -} - -// nextTrafficSecret generates the next traffic secret, given the current one, -// according to RFC 8446, Section 7.2. -func (c *cipherSuiteTLS13) nextTrafficSecret(trafficSecret []byte) []byte { - return c.expandLabel(trafficSecret, trafficUpdateLabel, nil, c.hash.Size()) -} - -// trafficKey generates traffic keys according to RFC 8446, Section 7.3. -func (c *cipherSuiteTLS13) trafficKey(trafficSecret []byte) (key, iv []byte) { - key = c.expandLabel(trafficSecret, "key", nil, c.keyLen) - iv = c.expandLabel(trafficSecret, "iv", nil, aeadNonceLength) - return -} - -// finishedHash generates the Finished verify_data or PskBinderEntry according -// to RFC 8446, Section 4.4.4. See sections 4.4 and 4.2.11.2 for the baseKey -// selection. -func (c *cipherSuiteTLS13) finishedHash(baseKey []byte, transcript hash.Hash) []byte { - finishedKey := c.expandLabel(baseKey, "finished", nil, c.hash.Size()) - verifyData := hmac.New(c.hash.New, finishedKey) - verifyData.Write(transcript.Sum(nil)) - return verifyData.Sum(nil) -} - -// exportKeyingMaterial implements RFC5705 exporters for TLS 1.3 according to -// RFC 8446, Section 7.5. -func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript hash.Hash) func(string, []byte, int) ([]byte, error) { - expMasterSecret := c.deriveSecret(masterSecret, exporterLabel, transcript) - return func(label string, context []byte, length int) ([]byte, error) { - secret := c.deriveSecret(expMasterSecret, label, nil) - h := c.hash.New() - h.Write(context) - return c.expandLabel(secret, "exporter", h.Sum(nil), length), nil - } -} - -// generateECDHEKey returns a PrivateKey that implements Diffie-Hellman -// according to RFC 8446, Section 4.2.8.2. -func generateECDHEKey(rand io.Reader, curveID CurveID) (*ecdh.PrivateKey, error) { - curve, ok := curveForCurveID(curveID) - if !ok { - return nil, errors.New("tls: internal error: unsupported curve") - } - - return curve.GenerateKey(rand) -} - -func curveForCurveID(id CurveID) (ecdh.Curve, bool) { - switch id { - case X25519: - return ecdh.X25519(), true - case CurveP256: - return ecdh.P256(), true - case CurveP384: - return ecdh.P384(), true - case CurveP521: - return ecdh.P521(), true - default: - return nil, false - } -} - -func curveIDForCurve(curve ecdh.Curve) (CurveID, bool) { - switch curve { - case ecdh.X25519(): - return X25519, true - case ecdh.P256(): - return CurveP256, true - case ecdh.P384(): - return CurveP384, true - case ecdh.P521(): - return CurveP521, true - default: - return 0, false - } -} diff --git a/pkg/tls/key_schedule_test.go b/pkg/tls/key_schedule_test.go deleted file mode 100644 index 79ff6a62b..000000000 --- a/pkg/tls/key_schedule_test.go +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright 2018 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "bytes" - "encoding/hex" - "hash" - "strings" - "testing" - "unicode" -) - -// This file contains tests derived from draft-ietf-tls-tls13-vectors-07. - -func parseVector(v string) []byte { - v = strings.Map(func(c rune) rune { - if unicode.IsSpace(c) { - return -1 - } - return c - }, v) - parts := strings.Split(v, ":") - v = parts[len(parts)-1] - res, err := hex.DecodeString(v) - if err != nil { - panic(err) - } - return res -} - -func TestDeriveSecret(t *testing.T) { - chTranscript := cipherSuitesTLS13[0].hash.New() - chTranscript.Write(parseVector(` - payload (512 octets): 01 00 01 fc 03 03 1b c3 ce b6 bb e3 9c ff - 93 83 55 b5 a5 0a db 6d b2 1b 7a 6a f6 49 d7 b4 bc 41 9d 78 76 - 48 7d 95 00 00 06 13 01 13 03 13 02 01 00 01 cd 00 00 00 0b 00 - 09 00 00 06 73 65 72 76 65 72 ff 01 00 01 00 00 0a 00 14 00 12 - 00 1d 00 17 00 18 00 19 01 00 01 01 01 02 01 03 01 04 00 33 00 - 26 00 24 00 1d 00 20 e4 ff b6 8a c0 5f 8d 96 c9 9d a2 66 98 34 - 6c 6b e1 64 82 ba dd da fe 05 1a 66 b4 f1 8d 66 8f 0b 00 2a 00 - 00 00 2b 00 03 02 03 04 00 0d 00 20 00 1e 04 03 05 03 06 03 02 - 03 08 04 08 05 08 06 04 01 05 01 06 01 02 01 04 02 05 02 06 02 - 02 02 00 2d 00 02 01 01 00 1c 00 02 40 01 00 15 00 57 00 00 00 - 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 - 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 - 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 - 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 - 00 29 00 dd 00 b8 00 b2 2c 03 5d 82 93 59 ee 5f f7 af 4e c9 00 - 00 00 00 26 2a 64 94 dc 48 6d 2c 8a 34 cb 33 fa 90 bf 1b 00 70 - ad 3c 49 88 83 c9 36 7c 09 a2 be 78 5a bc 55 cd 22 60 97 a3 a9 - 82 11 72 83 f8 2a 03 a1 43 ef d3 ff 5d d3 6d 64 e8 61 be 7f d6 - 1d 28 27 db 27 9c ce 14 50 77 d4 54 a3 66 4d 4e 6d a4 d2 9e e0 - 37 25 a6 a4 da fc d0 fc 67 d2 ae a7 05 29 51 3e 3d a2 67 7f a5 - 90 6c 5b 3f 7d 8f 92 f2 28 bd a4 0d da 72 14 70 f9 fb f2 97 b5 - ae a6 17 64 6f ac 5c 03 27 2e 97 07 27 c6 21 a7 91 41 ef 5f 7d - e6 50 5e 5b fb c3 88 e9 33 43 69 40 93 93 4a e4 d3 57 fa d6 aa - cb 00 21 20 3a dd 4f b2 d8 fd f8 22 a0 ca 3c f7 67 8e f5 e8 8d - ae 99 01 41 c5 92 4d 57 bb 6f a3 1b 9e 5f 9d`)) - - type args struct { - secret []byte - label string - transcript hash.Hash - } - tests := []struct { - name string - args args - want []byte - }{ - { - `derive secret for handshake "tls13 derived"`, - args{ - parseVector(`PRK (32 octets): 33 ad 0a 1c 60 7e c0 3b 09 e6 cd 98 93 68 0c e2 - 10 ad f3 00 aa 1f 26 60 e1 b2 2e 10 f1 70 f9 2a`), - "derived", - nil, - }, - parseVector(`expanded (32 octets): 6f 26 15 a1 08 c7 02 c5 67 8f 54 fc 9d ba - b6 97 16 c0 76 18 9c 48 25 0c eb ea c3 57 6c 36 11 ba`), - }, - { - `derive secret "tls13 c e traffic"`, - args{ - parseVector(`PRK (32 octets): 9b 21 88 e9 b2 fc 6d 64 d7 1d c3 29 90 0e 20 bb - 41 91 50 00 f6 78 aa 83 9c bb 79 7c b7 d8 33 2c`), - "c e traffic", - chTranscript, - }, - parseVector(`expanded (32 octets): 3f bb e6 a6 0d eb 66 c3 0a 32 79 5a ba 0e - ff 7e aa 10 10 55 86 e7 be 5c 09 67 8d 63 b6 ca ab 62`), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - c := cipherSuitesTLS13[0] - if got := c.deriveSecret(tt.args.secret, tt.args.label, tt.args.transcript); !bytes.Equal(got, tt.want) { - t.Errorf("cipherSuiteTLS13.deriveSecret() = % x, want % x", got, tt.want) - } - }) - } -} - -func TestTrafficKey(t *testing.T) { - trafficSecret := parseVector( - `PRK (32 octets): b6 7b 7d 69 0c c1 6c 4e 75 e5 42 13 cb 2d 37 b4 - e9 c9 12 bc de d9 10 5d 42 be fd 59 d3 91 ad 38`) - wantKey := parseVector( - `key expanded (16 octets): 3f ce 51 60 09 c2 17 27 d0 f2 e4 e8 6e - e4 03 bc`) - wantIV := parseVector( - `iv expanded (12 octets): 5d 31 3e b2 67 12 76 ee 13 00 0b 30`) - - c := cipherSuitesTLS13[0] - gotKey, gotIV := c.trafficKey(trafficSecret) - if !bytes.Equal(gotKey, wantKey) { - t.Errorf("cipherSuiteTLS13.trafficKey() gotKey = % x, want % x", gotKey, wantKey) - } - if !bytes.Equal(gotIV, wantIV) { - t.Errorf("cipherSuiteTLS13.trafficKey() gotIV = % x, want % x", gotIV, wantIV) - } -} - -func TestExtract(t *testing.T) { - type args struct { - newSecret []byte - currentSecret []byte - } - tests := []struct { - name string - args args - want []byte - }{ - { - `extract secret "early"`, - args{ - nil, - nil, - }, - parseVector(`secret (32 octets): 33 ad 0a 1c 60 7e c0 3b 09 e6 cd 98 93 68 0c - e2 10 ad f3 00 aa 1f 26 60 e1 b2 2e 10 f1 70 f9 2a`), - }, - { - `extract secret "master"`, - args{ - nil, - parseVector(`salt (32 octets): 43 de 77 e0 c7 77 13 85 9a 94 4d b9 db 25 90 b5 - 31 90 a6 5b 3e e2 e4 f1 2d d7 a0 bb 7c e2 54 b4`), - }, - parseVector(`secret (32 octets): 18 df 06 84 3d 13 a0 8b f2 a4 49 84 4c 5f 8a - 47 80 01 bc 4d 4c 62 79 84 d5 a4 1d a8 d0 40 29 19`), - }, - { - `extract secret "handshake"`, - args{ - parseVector(`IKM (32 octets): 8b d4 05 4f b5 5b 9d 63 fd fb ac f9 f0 4b 9f 0d - 35 e6 d6 3f 53 75 63 ef d4 62 72 90 0f 89 49 2d`), - parseVector(`salt (32 octets): 6f 26 15 a1 08 c7 02 c5 67 8f 54 fc 9d ba b6 97 - 16 c0 76 18 9c 48 25 0c eb ea c3 57 6c 36 11 ba`), - }, - parseVector(`secret (32 octets): 1d c8 26 e9 36 06 aa 6f dc 0a ad c1 2f 74 1b - 01 04 6a a6 b9 9f 69 1e d2 21 a9 f0 ca 04 3f be ac`), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - c := cipherSuitesTLS13[0] - if got := c.extract(tt.args.newSecret, tt.args.currentSecret); !bytes.Equal(got, tt.want) { - t.Errorf("cipherSuiteTLS13.extract() = % x, want % x", got, tt.want) - } - }) - } -} diff --git a/pkg/tls/ktls.go b/pkg/tls/ktls.go deleted file mode 100644 index 69e4ba647..000000000 --- a/pkg/tls/ktls.go +++ /dev/null @@ -1,17 +0,0 @@ -package tls - -import ( - "os" - "strings" -) - -var kTLSEnabled bool - -// kTLSCipher is a placeholder to tell the record layer to skip wrapping. -type kTLSCipher struct{} - -func init() { - kTLSEnabled = strings.ToLower(os.Getenv("GOKTLS")) == "true" || - strings.ToLower(os.Getenv("GOKTLS")) == "on" || - os.Getenv("GOKTLS") == "1" -} diff --git a/pkg/tls/ktls_cipher_linux.go b/pkg/tls/ktls_cipher_linux.go deleted file mode 100644 index 1ac2c361b..000000000 --- a/pkg/tls/ktls_cipher_linux.go +++ /dev/null @@ -1,414 +0,0 @@ -//go:build linux -// +build linux - -package tls - -import ( - "fmt" - "syscall" - "unsafe" -) - -const ( - kTLS_CIPHER_AES_GCM_128 = 51 - kTLS_CIPHER_AES_GCM_128_IV_SIZE = 8 - kTLS_CIPHER_AES_GCM_128_KEY_SIZE = 16 - kTLS_CIPHER_AES_GCM_128_SALT_SIZE = 4 - kTLS_CIPHER_AES_GCM_128_TAG_SIZE = 16 - kTLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE = 8 - - kTLS_CIPHER_AES_GCM_256 = 52 - kTLS_CIPHER_AES_GCM_256_IV_SIZE = 8 - kTLS_CIPHER_AES_GCM_256_KEY_SIZE = 32 - kTLS_CIPHER_AES_GCM_256_SALT_SIZE = 4 - kTLS_CIPHER_AES_GCM_256_TAG_SIZE = 16 - kTLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE = 8 - - // AES_CCM_128 is not used as it has not been implemented in golang - kTLS_CIPHER_AES_CCM_128 = 53 - kTLS_CIPHER_AES_CCM_128_IV_SIZE = 8 - kTLS_CIPHER_AES_CCM_128_KEY_SIZE = 16 - kTLS_CIPHER_AES_CCM_128_SALT_SIZE = 4 - kTLS_CIPHER_AES_CCM_128_TAG_SIZE = 16 - kTLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE = 8 - - kTLS_CIPHER_CHACHA20_POLY1305 = 54 - kTLS_CIPHER_CHACHA20_POLY1305_IV_SIZE = 12 - kTLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE = 32 - kTLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE = 0 - kTLS_CIPHER_CHACHA20_POLY1305_TAG_SIZE = 16 - kTLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE = 8 -) - -type kTLSCryptoInfo struct { - version uint16 - cipherType uint16 -} - -type kTLSCryptoInfoAESGCM128 struct { - info kTLSCryptoInfo - iv [kTLS_CIPHER_AES_GCM_128_IV_SIZE]byte - key [kTLS_CIPHER_AES_GCM_128_KEY_SIZE]byte - salt [kTLS_CIPHER_AES_GCM_128_SALT_SIZE]byte - recSeq [kTLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE]byte -} - -type kTLSCryptoInfoAESGCM256 struct { - info kTLSCryptoInfo - iv [kTLS_CIPHER_AES_GCM_256_IV_SIZE]byte - key [kTLS_CIPHER_AES_GCM_256_KEY_SIZE]byte - salt [kTLS_CIPHER_AES_GCM_256_SALT_SIZE]byte - recSeq [kTLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE]byte -} - -// AES_CCM_128 is not used as it has not been implemented in golang -type kTLSCryptoInfoAESCCM128 struct { - info kTLSCryptoInfo - iv [kTLS_CIPHER_AES_CCM_128_IV_SIZE]byte - key [kTLS_CIPHER_AES_CCM_128_KEY_SIZE]byte - salt [kTLS_CIPHER_AES_CCM_128_SALT_SIZE]byte - recSeq [kTLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE]byte -} - -type kTLSCryptoInfoCHACHA20POLY1305 struct { - info kTLSCryptoInfo - iv [kTLS_CIPHER_CHACHA20_POLY1305_IV_SIZE]byte - key [kTLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE]byte - salt [kTLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE]byte - recSeq [kTLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE]byte -} - -const ( - kTLSCryptoInfoSize_AES_GCM_128 = 2 + 2 + kTLS_CIPHER_AES_GCM_128_IV_SIZE + kTLS_CIPHER_AES_GCM_128_KEY_SIZE + - kTLS_CIPHER_AES_GCM_128_SALT_SIZE + kTLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE - - kTLSCryptoInfoSize_AES_GCM_256 = 2 + 2 + kTLS_CIPHER_AES_GCM_256_IV_SIZE + kTLS_CIPHER_AES_GCM_256_KEY_SIZE + - kTLS_CIPHER_AES_GCM_256_SALT_SIZE + kTLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE - - kTLSCryptoInfoSize_AES_CCM_128 = 2 + 2 + kTLS_CIPHER_AES_CCM_128_IV_SIZE + kTLS_CIPHER_AES_CCM_128_KEY_SIZE + - kTLS_CIPHER_AES_CCM_128_SALT_SIZE + kTLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE - - kTLSCryptoInfoSize_CHACHA20_POLY1305 = 2 + 2 + kTLS_CIPHER_CHACHA20_POLY1305_IV_SIZE + kTLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE + - kTLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE + kTLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE -) - -func ktlsEnableAES( - c *Conn, - version uint16, - enableFunc func(fd int, version uint16, opt int, skip bool, key, iv, seq []byte) error, - keyLen int, - inKey, outKey, inIV, outIV []byte) error { - var ulpEnabled bool - - // Try to enable Kernel TLS TX - if !kTLSSupportTX { - return nil - } - if len(outKey) == keyLen { - if sock, ok := c.conn.(Socket); ok { - if err := enableFunc(sock.Fd(), version, TLS_TX, ulpEnabled, outKey, outIV[:], c.out.seq[:]); err != nil { - Debugln("kTLS: TLS_TX error enabling:", err) - return err - } - ulpEnabled = true - Debugln("kTLS: TLS_TX enabled") - c.out.cipher = kTLSCipher{} - // Try to enable kTLS TX zerocopy sendfile. - // Only enabled if the hardware supports the protocol. - // Otherwise, get an error message which is fine. - ktlsEnableTxZerocopySendfile(sock.Fd()) - } else { - Debugln("kTLS: TLS_TX unsupported connection type") - } - } else { - Debugln("kTLS: TLS_TX unsupported key length") - } - - // Try to enable Kernel TLS RX for TLS 1.2 or TLS 1.3 (TLS 1.3 RX is disabled on kernel < 5.19 ) - if !kTLSSupportRX || (version == VersionTLS13 && !kTLSSupportTLS13RX) { - return nil - } - if len(inKey) == keyLen { - if sock, ok := c.conn.(Socket); ok { - if err := enableFunc(sock.Fd(), version, TLS_RX, ulpEnabled, inKey, inIV[:], c.in.seq[:]); err != nil { - Debugln("kTLS: TLS_RX error enabling:", err) - return err - } - Debugln("kTLS: TLS_RX enabled") - c.in.cipher = kTLSCipher{} - // Only enable the TLS_RX_EXPECT_NO_PAD for TLS 1.3 - // TODO: safe to enable only if the remote end is trusted, otherwise - // it is an attack vector to doubling the TLS processing cost. - // See: https://docs.kernel.org/networking/tls.html#tls-rx-expect-no-pad - if version == VersionTLS13 { - ktlsEnableRxExpectNoPad(sock.Fd()) - } - } else { - Debugln("kTLS: TLS_RX unsupported connection type") - } - } else { - Debugln("kTLS: TLS_TX unsupported key length") - } - - return nil -} - -func ktlsEnableCHACHA20(c *Conn, version uint16, inKey, outKey, inIV, outIV []byte) error { - var ulpEnabled bool - - // Try to enable Kernel TLS TX - if !kTLSSupportTX { - return nil - } - if sock, ok := c.conn.(Socket); ok { - err := ktlsEnableCHACHA20POLY1305(sock.Fd(), version, TLS_TX, ulpEnabled, outKey, outIV, c.out.seq[:]) - if err != nil { - Debugln("kTLS: TLS_TX error enabling:", err) - return err - } - ulpEnabled = true - Debugln("kTLS: TLS_TX enabled") - c.out.cipher = kTLSCipher{} - // Try to enable kTLS TX zerocopy sendfile. - // Only enabled if the hardware supports the protocol. - // Otherwise, get an error message which is fine. - ktlsEnableTxZerocopySendfile(sock.Fd()) - } else { - Debugln("kTLS: TLS_TX unsupported connection type") - } - - // Try to enable Kernel TLS RX for TLS 1.2 or TLS 1.3 (TLS 1.3 RX is disabled on kernel < 5.19 ) - if !kTLSSupportRX || (version == VersionTLS13 && !kTLSSupportTLS13RX) { - return nil - } - if sock, ok := c.conn.(Socket); ok { - err := ktlsEnableCHACHA20POLY1305(sock.Fd(), version, TLS_RX, ulpEnabled, inKey[:], inIV[:], c.in.seq[:]) - if err != nil { - Debugln("kTLS: TLS_RX error enabling:", err) - return err - } - Debugln("kTLS: TLS_RX enabled") - c.in.cipher = kTLSCipher{} - // Only enable the TLS_RX_EXPECT_NO_PAD for TLS 1.3 - // TODO: safe to enable only if the remote end is trusted, otherwise - // it is an attack vector to doubling the TLS processing cost. - // See: https://docs.kernel.org/networking/tls.html#tls-rx-expect-no-pad - if version == VersionTLS13 { - ktlsEnableRxExpectNoPad(sock.Fd()) - } - } else { - Debugln("kTLS: TLS_RX unsupported connection type") - } - - return nil -} - -func ktlsEnableAES128GCM(fd int, version uint16, opt int, skip bool, key, iv, seq []byte) (err error) { - if len(key) != kTLS_CIPHER_AES_GCM_128_KEY_SIZE { - return fmt.Errorf("kTLS: wrong key length, desired: %d, actual: %d", - kTLS_CIPHER_AES_GCM_128_KEY_SIZE, len(key)) - } - if version == VersionTLS12 { - // The nounce of TLS 1.2 only has 4 bytes. So, compare with kTLS_CIPHER_AES_GCM_128_SALT_SIZE only - if len(iv) != kTLS_CIPHER_AES_GCM_128_SALT_SIZE { - return fmt.Errorf("kTLS: wrong iv length, desired: %d, actual: %d", - kTLS_CIPHER_AES_GCM_128_SALT_SIZE, len(iv)) - } - if len(seq) != kTLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE { - return fmt.Errorf("kTLS: wrong seq length, desired: %d, actual: %d", - kTLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE, len(seq)) - } - } else { - // The nounce of TLS 1.3 only has 12 bytes. So, compare with - // kTLS_CIPHER_AES_GCM_128_SALT_SIZE + kTLS_CIPHER_AES_GCM_128_IV_SIZE - if len(iv) != kTLS_CIPHER_AES_GCM_128_SALT_SIZE+kTLS_CIPHER_AES_GCM_128_IV_SIZE { - return fmt.Errorf("kTLS: wrong iv length, desired: %d, actual: %d", - kTLS_CIPHER_AES_GCM_128_SALT_SIZE+kTLS_CIPHER_AES_GCM_128_IV_SIZE, len(iv)) - } - if len(seq) != kTLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE { - return fmt.Errorf("kTLS: wrong seq length, desired: %d, actual: %d", - kTLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE, len(seq)) - } - } - - cryptoInfo := kTLSCryptoInfoAESGCM128{ - info: kTLSCryptoInfo{ - version: version, - cipherType: kTLS_CIPHER_AES_GCM_128, - }, - } - - Debugf("\nkey: %x\niv: %x\nseq: %x", key, iv, seq) - copy(cryptoInfo.key[:], key) - copy(cryptoInfo.salt[:], iv[:kTLS_CIPHER_AES_GCM_128_SALT_SIZE]) - // TODO https://github.com/FiloSottile/go/blob/filippo%2FkTLS/src/crypto/tls/ktls.go#L73 - // the PoC of FiloSottile here is copy(cryptoInfo.iv[:], seq) - // For TLS 1.2, its IV is 0, whereas TLS 1.3 uses the rest of 8 bytes - copy(cryptoInfo.iv[:], iv[kTLS_CIPHER_AES_GCM_128_SALT_SIZE:]) - copy(cryptoInfo.recSeq[:], seq) - - // Assert padding isn't introduced by alignment requirements. - if unsafe.Sizeof(cryptoInfo) != kTLSCryptoInfoSize_AES_GCM_128 { - return fmt.Errorf("kTLS: wrong cryptoInfo size, desired: %d, actual: %d", - kTLSCryptoInfoSize_AES_GCM_128, unsafe.Sizeof(cryptoInfo)) - } - - if !skip { - err = syscall.SetsockoptString(fd, syscall.SOL_TCP, TCP_ULP, "tls") - if err != nil { - Debugln("kTLS: setsockopt(SOL_TCP, TCP_ULP) failed:", err) - return - } - } - err = syscall.SetsockoptString(fd, SOL_TLS, opt, - string((*[kTLSCryptoInfoSize_AES_GCM_128]byte)(unsafe.Pointer(&cryptoInfo))[:])) - if err != nil { - Debugf("kTLS: setsockopt(SOL_TLS, %d) failed: %s", opt, err) - return - } - - return err -} - -func ktlsEnableAES256GCM(fd int, version uint16, opt int, skip bool, key, iv, seq []byte) (err error) { - if len(key) != kTLS_CIPHER_AES_GCM_256_KEY_SIZE { - return fmt.Errorf("kTLS: wrong key length, desired: %d, actual: %d", - kTLS_CIPHER_AES_GCM_256_KEY_SIZE, len(key)) - } - if version == VersionTLS12 { - // The nounce of TLS 1.2 only has 4 bytes. So, compare with kTLS_CIPHER_AES_GCM_256_SALT_SIZE only - if len(iv) != kTLS_CIPHER_AES_GCM_256_SALT_SIZE { - return fmt.Errorf("kTLS: wrong iv length, desired: %d, actual: %d", - kTLS_CIPHER_AES_GCM_256_SALT_SIZE, len(iv)) - } - if len(seq) != kTLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE { - return fmt.Errorf("kTLS: wrong seq length, desired: %d, actual: %d", - kTLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE, len(seq)) - } - } else { - // The nounce of TLS 1.3 only has 12 bytes. So, compare with - // kTLS_CIPHER_AES_GCM_256_SALT_SIZE + kTLS_CIPHER_AES_GCM_256_IV_SIZE - if len(iv) != kTLS_CIPHER_AES_GCM_256_SALT_SIZE+kTLS_CIPHER_AES_GCM_256_IV_SIZE { - return fmt.Errorf("kTLS: wrong iv length, desired: %d, actual: %d", - kTLS_CIPHER_AES_GCM_256_SALT_SIZE+kTLS_CIPHER_AES_GCM_256_IV_SIZE, len(iv)) - } - if len(seq) != kTLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE { - return fmt.Errorf("kTLS: wrong seq length, desired: %d, actual: %d", - kTLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE, len(seq)) - } - } - - cryptoInfo := kTLSCryptoInfoAESGCM256{ - info: kTLSCryptoInfo{ - version: version, - cipherType: kTLS_CIPHER_AES_GCM_256, - }, - } - - Debugf("key: %x\niv: %x\n seq: %x", key, iv, seq) - copy(cryptoInfo.key[:], key) - copy(cryptoInfo.salt[:], iv[:kTLS_CIPHER_AES_GCM_256_SALT_SIZE]) - // TODO https://github.com/FiloSottile/go/blob/filippo%2FkTLS/src/crypto/tls/ktls.go#L73 - // the PoC of FiloSottile here is copy(cryptoInfo.iv[:], seq) - // For TLS 1.2, its IV is 0, whereas TLS 1.3 uses the rest of 8 bytes - copy(cryptoInfo.iv[:], iv[kTLS_CIPHER_AES_GCM_256_SALT_SIZE:]) - copy(cryptoInfo.recSeq[:], seq) - - // Assert padding isn't introduced by alignment requirements. - if unsafe.Sizeof(cryptoInfo) != kTLSCryptoInfoSize_AES_GCM_256 { - return fmt.Errorf("kTLS: wrong cryptoInfo size, desired: %d, actual: %d", - kTLSCryptoInfoSize_AES_GCM_256, unsafe.Sizeof(cryptoInfo)) - } - - if !skip { - err = syscall.SetsockoptString(fd, syscall.SOL_TCP, TCP_ULP, "tls") - if err != nil { - Debugln("kTLS: setsockopt(SOL_TCP, TCP_ULP) failed:", err) - return - } - } - err = syscall.SetsockoptString(fd, SOL_TLS, opt, - string((*[kTLSCryptoInfoSize_AES_GCM_256]byte)(unsafe.Pointer(&cryptoInfo))[:])) - if err != nil { - Debugf("kTLS: setsockopt(SOL_TLS, %d) failed: %s", opt, err) - return - } - - return err -} - -func ktlsEnableCHACHA20POLY1305(fd int, version uint16, opt int, skip bool, key, iv, seq []byte) (err error) { - if len(key) != kTLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE { - return fmt.Errorf("kTLS: wrong key length, desired: %d, actual: %d", - kTLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE, len(key)) - } - if len(iv) != kTLS_CIPHER_CHACHA20_POLY1305_IV_SIZE { - return fmt.Errorf("kTLS: wrong iv length, desired: %d, actual: %d", - kTLS_CIPHER_CHACHA20_POLY1305_IV_SIZE, len(iv)) - } - if len(seq) != kTLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE { - return fmt.Errorf("kTLS: wrong seq length, desired: %d, actual: %d", - kTLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE, len(seq)) - } - - cryptoInfo := kTLSCryptoInfoCHACHA20POLY1305{ - info: kTLSCryptoInfo{ - version: version, - cipherType: kTLS_CIPHER_CHACHA20_POLY1305, - }, - } - - Debugf("\nkey: %x\niv: %x\nseq: %x", key, iv, seq) - copy(cryptoInfo.key[:], key) - copy(cryptoInfo.iv[:], iv) - // the salt of CHACHA20POLY1305 is 0 bytes. So, no need to copy - copy(cryptoInfo.recSeq[:], seq) - - // Assert padding isn't introduced by alignment requirements. - if unsafe.Sizeof(cryptoInfo) != kTLSCryptoInfoSize_CHACHA20_POLY1305 { - return fmt.Errorf("kTLS: wrong cryptoInfo size, desired: %d, actual: %d", - kTLSCryptoInfoSize_CHACHA20_POLY1305, unsafe.Sizeof(cryptoInfo)) - } - - if !skip { - err = syscall.SetsockoptString(fd, syscall.SOL_TCP, TCP_ULP, "tls") - if err != nil { - Debugln("kTLS: setsockopt(SOL_TCP, TCP_ULP) failed:", err) - return - } - } - err = syscall.SetsockoptString(fd, SOL_TLS, opt, - string((*[kTLSCryptoInfoSize_CHACHA20_POLY1305]byte)(unsafe.Pointer(&cryptoInfo))[:])) - if err != nil { - Debugf("kTLS: setsockopt(SOL_TLS, %d) failed: %s", opt, err) - return - } - - return err -} - -func ktlsEnableTxZerocopySendfile(fd int) (err error) { - if !kTLSSupportZEROCOPY { - return nil - } - err = syscall.SetsockoptInt(int(fd), SOL_TLS, TLS_TX_ZEROCOPY_RO, 1) - if err != nil { - Debugf("kTLS: TLS_TX Zerocopy Sendfile not Enabled. Error: %s", err) - return - } - Debugln("kTLS: TLS_TX Zerocopy Sendfile Enabled") - return -} - -func ktlsEnableRxExpectNoPad(fd int) (err error) { - if !kTLSSupportNOPAD { - return nil - } - err = syscall.SetsockoptInt(int(fd), SOL_TLS, TLS_RX_EXPECT_NO_PAD, 1) - if err != nil { - Debugf("kTLS: TLS_RX Expect No Pad not Enabled. Error: %s", err) - return - } - Debugln("kTLS: TLS_RX Expect No Pad Enabled") - return -} diff --git a/pkg/tls/ktls_io.go b/pkg/tls/ktls_io.go deleted file mode 100644 index b460c9296..000000000 --- a/pkg/tls/ktls_io.go +++ /dev/null @@ -1,36 +0,0 @@ -package tls - -import "io" - -// LimitWriter is a copy of the standard library ioutils.LimitReader, -// applied to the writer interface. -// LimitWriter returns a Writer that writes to w -// but stops with EOF after n bytes. -// The underlying implementation is a *LimitedWriter. -func LimitWriter(w io.Writer, n int64) io.Writer { return &LimitedWriter{w, n} } - -// A LimitedWriter writes to W but limits the amount of -// data returned to just N bytes. Each call to Write -// updates N to reflect the new amount remaining. -// Write returns EOF when N <= 0 or when the underlying W returns EOF. -type LimitedWriter struct { - W io.Writer // underlying writer - N int64 // max bytes remaining -} - -func (l *LimitedWriter) Write(p []byte) (n int, err error) { - if l.N <= 0 { - return 0, io.ErrShortWrite - } - truncated := false - if int64(len(p)) > l.N { - p = p[0:l.N] - truncated = true - } - n, err = l.W.Write(p) - l.N -= int64(n) - if err == nil && truncated { - err = io.ErrShortWrite - } - return -} \ No newline at end of file diff --git a/pkg/tls/ktls_linux.go b/pkg/tls/ktls_linux.go deleted file mode 100644 index 85d051860..000000000 --- a/pkg/tls/ktls_linux.go +++ /dev/null @@ -1,534 +0,0 @@ -//go:build linux -// +build linux - -package tls - -import ( - "fmt" - "io" - "net" - "os" - "strconv" - "strings" - "unsafe" - - "golang.org/x/sys/unix" -) - -const ( - TCP_ULP = 31 - SOL_TLS = 282 - - TLS_TX = 1 - TLS_RX = 2 - TLS_TX_ZEROCOPY_RO = 3 // TX zerocopy (only sendfile now) - TLS_RX_EXPECT_NO_PAD = 4 // Attempt opportunistic zero-copy, TLS 1.3 only - - TLS_SET_RECORD_TYPE = 1 - TLS_GET_RECORD_TYPE = 2 - - kTLSOverhead = 16 -) - -var ( - kTLSSupport bool - - // kTLSSupportTX is true when kTLSSupport is true - kTLSSupportTX bool - kTLSSupportRX bool - - // kTLSSupportAESGCM128 is true when kTLSSupport is true - kTLSSupportAESGCM128 bool - kTLSSupportAESGCM256 bool - kTLSSupportCHACHA20POLY1305 bool - - kTLSSupportTLS13TX bool - // TLS1.3 RX is buggy in kernel 5.15, got weird package lost - // TODO: test it on kernel 5.19 or 6+ - kTLSSupportTLS13RX bool - - // available in kernel >= 5.19 or 6+ - kTLSSupportZEROCOPY bool - - // available in kernel 6+ - kTLSSupportNOPAD bool -) - -func init() { - // when kernel tls module enabled, /sys/module/tls is available - _, err := os.Stat("/sys/module/tls") - if err != nil { - Debugln("kTLS: kernel tls module not enabled") - return - } - kTLSSupport = true && kTLSEnabled - Debugf("kTLS Enabled Status: %v", kTLSSupport) - // no need to check further, as KTLS is disabled - if !kTLSSupport { - return - } - - var uname unix.Utsname - if err := unix.Uname(&uname); err != nil { - Debugf("kTLS: call uname failed %v", err) - return - } - - var buf [65]byte - for i, b := range uname.Release { - buf[i] = byte(b) - } - release := string(buf[:]) - if i := strings.Index(release, "\x00"); i != -1 { - release = release[:i] - } - majorRelease := release[:strings.Index(release, ".")] - minorRelease := strings.TrimLeft(release, majorRelease+".") - minorRelease = minorRelease[:strings.Index(minorRelease, ".")] - major, err := strconv.Atoi(majorRelease) - if err != nil { - Debugf("kTLS: parse major release failed %v", err) - return - } - minor, err := strconv.Atoi(minorRelease) - if err != nil { - Debugf("kTLS: parse minor release failed %v", err) - return - } - - Debugf("Kernel Version: %s", release) - - if (major == 4 && minor >= 13) || major > 4 { - kTLSSupportTX = true - kTLSSupportAESGCM128 = true - } - - if (major == 4 && minor >= 17) || major > 4 { - kTLSSupportRX = true - } - - if (major == 5 && minor >= 1) || major > 5 { - kTLSSupportAESGCM256 = true - kTLSSupportTLS13TX = true - } - - if (major == 5 && minor >= 11) || major > 5 { - kTLSSupportCHACHA20POLY1305 = true - } - - if (major == 5 && minor >= 19) || major > 5 { - kTLSSupportZEROCOPY = true - kTLSSupportTLS13RX = true - } - - if major > 5 { - kTLSSupportNOPAD = true - } - - Debugln("======Supported Features======") - Debugf("kTLS TX: %v", kTLSSupportTX) - Debugf("kTLS RX: %v", kTLSSupportRX) - Debugf("kTLS TLS 1.3 TX: %v", kTLSSupportTLS13TX) - Debugf("kTLS TLS 1.3 RX: %v", kTLSSupportTLS13RX) - Debugf("kTLS TX ZeroCopy: %v", kTLSSupportZEROCOPY) - Debugf("kTLS RX Expected No Pad: %v", kTLSSupportNOPAD) - - Debugln("=========CipherSuites=========") - Debugf("kTLS AES-GCM-128: %v", kTLSSupportAESGCM128) - Debugf("kTLS AES-GCM-256: %v", kTLSSupportAESGCM256) - Debugf("kTLS CHACHA20POLY1305: %v", kTLSSupportCHACHA20POLY1305) -} - -func (c *Conn) ReadFrom(r io.Reader) (n int64, err error) { - if err := c.Handshake(); err != nil { - return 0, err - } - return io.Copy(c.conn, r) -} - -const maxBufferSize int64 = 4 * 1024 * 1024 - -func (c *Conn) writeToFile(f *os.File, remain int64) (written int64, err error, handled bool) { - if remain <= 0 { - return 0, nil, false - } - offset, err := f.Seek(0, io.SeekCurrent) - if err != nil { - return 0, nil, false - } - fi, err := f.Stat() - if err != nil { - return 0, nil, false - } - if offset+remain > fi.Size() { - err = f.Truncate(offset + remain) - if err != nil { - Debugf("file truncate error: %s", err) - return 0, nil, false - } - } - - // mmap must align on a page boundary - // mmap from 0, use data from offset - bytes, err := unix.Mmap(int(f.Fd()), 0, int(offset+remain), - unix.PROT_WRITE, unix.MAP_SHARED) - if err != nil { - return 0, nil, false - } - defer unix.Munmap(bytes) - - bytes = bytes[offset : offset+remain] - var ( - start = int64(0) - end = maxBufferSize - ) - - for { - if end > remain { - end = remain - } - //now := time.Now() - n, err := c.Read(bytes[start:end]) - if err != nil { - return start + int64(n), err, true - } - //log.Printf("read %d bytes, cost %dus", n, time.Since(now).Microseconds()) - start += int64(n) - if start >= remain { - break - } - - end += int64(n) - } - return remain, nil, true -} - -var maxSpliceSize int64 = 4 << 20 - -func (c *Conn) spliceToFile(f *os.File, remain int64) (written int64, err error, handled bool) { - tcpConn, ok := c.conn.(*net.TCPConn) - if !ok { - return 0, nil, false - } - sc, err := tcpConn.SyscallConn() - if err != nil { - return 0, nil, false - } - fsc, err := f.SyscallConn() - if err != nil { - return 0, nil, false - } - - var pipes [2]int - if err := unix.Pipe(pipes[:]); err != nil { - return 0, nil, false - } - - prfd, pwfd := pipes[0], pipes[1] - defer destroyTempPipe(prfd, pwfd) - - var ( - n = maxSpliceSize - m int64 - ) - - rerr := sc.Read(func(rfd uintptr) (done bool) { - for { - n = maxSpliceSize - if n > remain { - n = remain - } - // move tcp data to pipe - // FIXME should not use unix.SPLICE_F_NONBLOCK, when use this flag, ktls will not advance socket buffer - // refer: https://github.com/torvalds/linux/blob/v5.12/net/tls/tls_sw.c#L2021 - n, err = unix.Splice(int(rfd), nil, pwfd, nil, int(n), unix.SPLICE_F_MORE) - remain -= n - written += n - if err == unix.EAGAIN { - // return false to wait data from connection - err = nil - return false - } - - if err != nil { - break - } - - // move pipe data to file - werr := fsc.Write(func(wfd uintptr) (done bool) { - bump: - m, err = unix.Splice(prfd, nil, int(wfd), nil, int(n), - unix.SPLICE_F_MOVE|unix.SPLICE_F_MORE|unix.SPLICE_F_NONBLOCK) - if err != nil { - return true - } - if m < n { - n -= m - goto bump - } - return true - }) - if err == nil { - err = werr - } - if err != nil || remain <= 0 { - break - } - } - return true - }) - if err == nil { - err = rerr - } - return written, err, true -} - -// destroyTempPipe destroys a temporary pipe. -func destroyTempPipe(prfd, pwfd int) error { - err := unix.Close(prfd) - err1 := unix.Close(pwfd) - if err == nil { - return err1 - } - return err -} - -func (c *Conn) WriteTo(w io.Writer) (n int64, err error) { - if err := c.Handshake(); err != nil { - return 0, err - } - - if lw, ok := w.(*LimitedWriter); ok { - if f, ok := lw.W.(*os.File); ok { - n, err, handled := c.spliceToFile(f, lw.N) - if handled { - return n, err - } - } - } - - // FIXME read at least one record for io.EOF and so on ? - //if conn, ok := w.(*net.TCPConn); ok { - // buf := make([]byte, 16*1024) - // n, err := ktlsReadRecord(conn, buf) - // if err != nil { - // wn, _ := w.Write(buf[:n]) - // return int64(wn), err - // } - // wn, err := w.Write(buf[:n]) - // if err != nil { - // return int64(wn), err - // } - //} - return io.Copy(w, c.conn) -} - -func (c *Conn) IsKTLSTXEnabled() bool { - _, ok := c.out.cipher.(kTLSCipher) - return ok -} - -func (c *Conn) IsKTLSRXEnabled() bool { - _, ok := c.in.cipher.(kTLSCipher) - return ok -} - -func (c *Conn) enableKernelTLS(cipherSuiteID uint16, inKey, outKey, inIV, outIV []byte) error { - if !kTLSSupport { - return nil - } - switch cipherSuiteID { - // Kernel TLS 1.2 - case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_RSA_WITH_AES_128_GCM_SHA256: - if !kTLSSupportAESGCM128 { - return nil - } - Debugln("try to enable kernel tls AES_128_GCM for tls 1.2") - return ktlsEnableAES(c, VersionTLS12, ktlsEnableAES128GCM, 16, inKey, outKey, inIV, outIV) - case TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_RSA_WITH_AES_256_GCM_SHA384: - if !kTLSSupportAESGCM256 { - return nil - } - Debugln("try to enable kernel tls AES_256_GCM for tls 1.2") - return ktlsEnableAES(c, VersionTLS12, ktlsEnableAES256GCM, 32, inKey, outKey, inIV, outIV) - case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256: - if !kTLSSupportCHACHA20POLY1305 { - return nil - } - Debugln("try to enable kernel tls CHACHA20_POLY1305 for tls 1.2") - return ktlsEnableCHACHA20(c, VersionTLS12, inKey, outKey, inIV, outIV) - - // Kernel TLS 1.3 - case TLS_AES_128_GCM_SHA256: - if !kTLSSupportAESGCM128 || !kTLSSupportTLS13TX { - return nil - } - Debugln("try to enable kernel tls AES_128_GCM for tls 1.3") - return ktlsEnableAES(c, VersionTLS13, ktlsEnableAES128GCM, 16, inKey, outKey, inIV, outIV) - case TLS_AES_256_GCM_SHA384: - if !kTLSSupportAESGCM256 || !kTLSSupportTLS13TX { - return nil - } - Debugln("try to enable kernel tls AES_256_GCM tls 1.3") - return ktlsEnableAES(c, VersionTLS13, ktlsEnableAES256GCM, 32, inKey, outKey, inIV, outIV) - case TLS_CHACHA20_POLY1305_SHA256: - if !kTLSSupportCHACHA20POLY1305 || !kTLSSupportTLS13TX { - return nil - } - Debugln("try to enable kernel tls CHACHA20_POLY1305 for tls 1.3") - return ktlsEnableCHACHA20(c, VersionTLS13, inKey, outKey, inIV, outIV) - } - return nil -} - -func ktlsReadRecord(fd int, b []byte) (typ recordType, n int, err error) { - // cmsg for record type - buffer := make([]byte, unix.CmsgSpace(1)) - cmsg := (*unix.Cmsghdr)(unsafe.Pointer(&buffer[0])) - cmsg.SetLen(unix.CmsgLen(1)) - - var iov unix.Iovec - iov.Base = &b[0] - iov.SetLen(len(b)) - - var msg unix.Msghdr - msg.Control = &buffer[0] - msg.Controllen = cmsg.Len - msg.Iov = &iov - msg.Iovlen = 1 - - flags := 0 - n, err = recvmsg(uintptr(fd), &msg, flags) - if err == unix.EAGAIN { - // data is not ready, goroutine will be parked - return 0, n, err - } - // n should not be zero when err == nil - if err == nil && n == 0 { - err = io.EOF - } - - if err != nil { - Debugln("kTLS: recvmsg failed:", err) - // fix bufio panic due to n == -1 - if n == -1 { - n = 0 - } - return 0, n, err - } - - if n < 0 { - return 0, 0, fmt.Errorf("unknown size received: %d", n) - } else if n == 0 { - return 0, 0, nil - } - - if cmsg.Level != SOL_TLS { - Debugf("kTLS: unsupported cmsg level: %d", cmsg.Level) - return 0, 0, fmt.Errorf("unsupported cmsg level: %d", cmsg.Level) - } - if cmsg.Type != TLS_GET_RECORD_TYPE { - Debugf("kTLS: unsupported cmsg type: %d", cmsg.Type) - return 0, 0, fmt.Errorf("unsupported cmsg type: %d", cmsg.Type) - } - typ = recordType(buffer[unix.SizeofCmsghdr]) - Debugf("kTLS: recvmsg, type: %d, payload len: %d", typ, n) - return typ, n, nil -} - -func ktlsReadDataFromRecord(fd int, b []byte) (int, error) { - typ, n, err := ktlsReadRecord(fd, b) - if err != nil { - return n, err - } - switch typ { - case recordTypeAlert: - if n < 2 { - return 0, fmt.Errorf("ktls alert payload too short") - } - if alert(b[1]) == alertCloseNotify { - return 0, io.EOF - } - return 0, fmt.Errorf("unsupported ktls alert type: %d", b[0]) - case recordTypeApplicationData: - return n, nil - default: - return 0, fmt.Errorf("unsupported ktls record type: %d", typ) - } -} - -func recvmsg(fd uintptr, msg *unix.Msghdr, flags int) (n int, err error) { - r0, _, e1 := unix.Syscall(unix.SYS_RECVMSG, fd, uintptr(unsafe.Pointer(msg)), uintptr(flags)) - n = int(r0) - if e1 != 0 { - err = errnoErr(e1) - } - return -} - -func sendmsg(fd uintptr, msg *unix.Msghdr, flags int) (n int, err error) { - r0, _, e1 := unix.Syscall(unix.SYS_SENDMSG, fd, uintptr(unsafe.Pointer(msg)), uintptr(flags)) - n = int(r0) - if e1 != 0 { - err = errnoErr(e1) - } - return -} - -// Do the interface allocations only once for common -// Errno values. -var ( - errEAGAIN error = unix.EAGAIN - errEINVAL error = unix.EINVAL - errENOENT error = unix.ENOENT -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e unix.Errno) error { - switch e { - case 0: - return nil - case unix.EAGAIN: - return errEAGAIN - case unix.EINVAL: - return errEINVAL - case unix.ENOENT: - return errENOENT - } - return e -} - -func ktlsSendCtrlMessage(fd int, typ recordType, b []byte) (int, error) { - // cmsg for record type - buffer := make([]byte, unix.CmsgSpace(1)) - cmsg := (*unix.Cmsghdr)(unsafe.Pointer(&buffer[0])) - cmsg.SetLen(unix.CmsgLen(1)) - buffer[unix.SizeofCmsghdr] = byte(typ) - cmsg.Level = SOL_TLS - cmsg.Type = TLS_SET_RECORD_TYPE - - var iov unix.Iovec - iov.Base = &b[0] - iov.SetLen(len(b)) - - var msg unix.Msghdr - msg.Control = &buffer[0] - msg.Controllen = cmsg.Len - msg.Iov = &iov - msg.Iovlen = 1 - - var n int - flags := 0 - n, err := sendmsg(uintptr(fd), &msg, flags) - if err == unix.EAGAIN { - // data is not ready, goroutine will be parked - return n, err - } - if err != nil { - Debugln("kTLS: sendmsg failed:", err) - } - - Debugf("kTLS: sendmsg, type: %d, payload len: %d", typ, len(b)) - return n, err -} diff --git a/pkg/tls/ktls_log_debug.go b/pkg/tls/ktls_log_debug.go deleted file mode 100644 index da1352c29..000000000 --- a/pkg/tls/ktls_log_debug.go +++ /dev/null @@ -1,16 +0,0 @@ -//go:build debug -package tls - -import ( - "log" -) - -const Dev = true - -func Debugln(a ...interface{}) { - log.Println(a...) -} - -func Debugf(format string, a ...interface{}) { - log.Printf(format, a...) -} \ No newline at end of file diff --git a/pkg/tls/ktls_log_release.go b/pkg/tls/ktls_log_release.go deleted file mode 100644 index f5f5149cb..000000000 --- a/pkg/tls/ktls_log_release.go +++ /dev/null @@ -1,8 +0,0 @@ -//go:build !debug -package tls - -const Dev = false - -func Debugln(a ...interface{}) {} - -func Debugf(format string, a ...interface{}) {} \ No newline at end of file diff --git a/pkg/tls/ktls_others.go b/pkg/tls/ktls_others.go deleted file mode 100644 index 45ac05c07..000000000 --- a/pkg/tls/ktls_others.go +++ /dev/null @@ -1,26 +0,0 @@ -//go:build !linux -// +build !linux - -package tls - -import ( - "net" -) - -const kTLSOverhead = 0 - -func (c *Conn) enableKernelTLS(cipherSuiteID uint16, inKey, outKey, inIV, outIV []byte) error { - return nil -} - -func ktlsSendCtrlMessage(fd int, typ recordType, b []byte) (int, error) { - panic("not implement") -} - -func ktlsReadDataFromRecord(fd int, b []byte) (int, error) { - panic("not implement") -} - -func ktlsReadRecord(fd int, b []byte) (recordType, int, error) { - panic("not implement") -} diff --git a/pkg/tls/notboring.go b/pkg/tls/notboring.go deleted file mode 100644 index 7d85b39c5..000000000 --- a/pkg/tls/notboring.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright 2022 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build !boringcrypto - -package tls - -func needFIPS() bool { return false } - -func supportedSignatureAlgorithms() []SignatureScheme { - return defaultSupportedSignatureAlgorithms -} - -func fipsMinVersion(c *Config) uint16 { panic("fipsMinVersion") } -func fipsMaxVersion(c *Config) uint16 { panic("fipsMaxVersion") } -func fipsCurvePreferences(c *Config) []CurveID { panic("fipsCurvePreferences") } -func fipsCipherSuites(c *Config) []uint16 { panic("fipsCipherSuites") } - -var fipsSupportedSignatureAlgorithms []SignatureScheme diff --git a/pkg/tls/prf.go b/pkg/tls/prf.go deleted file mode 100644 index b60166dee..000000000 --- a/pkg/tls/prf.go +++ /dev/null @@ -1,283 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "crypto" - "crypto/hmac" - "crypto/md5" - "crypto/sha1" - "crypto/sha256" - "crypto/sha512" - "errors" - "fmt" - "hash" -) - -// Split a premaster secret in two as specified in RFC 4346, Section 5. -func splitPreMasterSecret(secret []byte) (s1, s2 []byte) { - s1 = secret[0 : (len(secret)+1)/2] - s2 = secret[len(secret)/2:] - return -} - -// pHash implements the P_hash function, as defined in RFC 4346, Section 5. -func pHash(result, secret, seed []byte, hash func() hash.Hash) { - h := hmac.New(hash, secret) - h.Write(seed) - a := h.Sum(nil) - - j := 0 - for j < len(result) { - h.Reset() - h.Write(a) - h.Write(seed) - b := h.Sum(nil) - copy(result[j:], b) - j += len(b) - - h.Reset() - h.Write(a) - a = h.Sum(nil) - } -} - -// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5. -func prf10(result, secret, label, seed []byte) { - hashSHA1 := sha1.New - hashMD5 := md5.New - - labelAndSeed := make([]byte, len(label)+len(seed)) - copy(labelAndSeed, label) - copy(labelAndSeed[len(label):], seed) - - s1, s2 := splitPreMasterSecret(secret) - pHash(result, s1, labelAndSeed, hashMD5) - result2 := make([]byte, len(result)) - pHash(result2, s2, labelAndSeed, hashSHA1) - - for i, b := range result2 { - result[i] ^= b - } -} - -// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, Section 5. -func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) { - return func(result, secret, label, seed []byte) { - labelAndSeed := make([]byte, len(label)+len(seed)) - copy(labelAndSeed, label) - copy(labelAndSeed[len(label):], seed) - - pHash(result, secret, labelAndSeed, hashFunc) - } -} - -const ( - masterSecretLength = 48 // Length of a master secret in TLS 1.1. - finishedVerifyLength = 12 // Length of verify_data in a Finished message. -) - -var masterSecretLabel = []byte("master secret") -var keyExpansionLabel = []byte("key expansion") -var clientFinishedLabel = []byte("client finished") -var serverFinishedLabel = []byte("server finished") - -func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) { - switch version { - case VersionTLS10, VersionTLS11: - return prf10, crypto.Hash(0) - case VersionTLS12: - if suite.flags&suiteSHA384 != 0 { - return prf12(sha512.New384), crypto.SHA384 - } - return prf12(sha256.New), crypto.SHA256 - default: - panic("unknown version") - } -} - -func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) { - prf, _ := prfAndHashForVersion(version, suite) - return prf -} - -// masterFromPreMasterSecret generates the master secret from the pre-master -// secret. See RFC 5246, Section 8.1. -func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte { - seed := make([]byte, 0, len(clientRandom)+len(serverRandom)) - seed = append(seed, clientRandom...) - seed = append(seed, serverRandom...) - - masterSecret := make([]byte, masterSecretLength) - prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed) - return masterSecret -} - -// keysFromMasterSecret generates the connection keys from the master -// secret, given the lengths of the MAC key, cipher key and IV, as defined in -// RFC 2246, Section 6.3. -func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) { - seed := make([]byte, 0, len(serverRandom)+len(clientRandom)) - seed = append(seed, serverRandom...) - seed = append(seed, clientRandom...) - - n := 2*macLen + 2*keyLen + 2*ivLen - keyMaterial := make([]byte, n) - prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed) - clientMAC = keyMaterial[:macLen] - keyMaterial = keyMaterial[macLen:] - serverMAC = keyMaterial[:macLen] - keyMaterial = keyMaterial[macLen:] - clientKey = keyMaterial[:keyLen] - keyMaterial = keyMaterial[keyLen:] - serverKey = keyMaterial[:keyLen] - keyMaterial = keyMaterial[keyLen:] - clientIV = keyMaterial[:ivLen] - keyMaterial = keyMaterial[ivLen:] - serverIV = keyMaterial[:ivLen] - return -} - -func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash { - var buffer []byte - if version >= VersionTLS12 { - buffer = []byte{} - } - - prf, hash := prfAndHashForVersion(version, cipherSuite) - if hash != 0 { - return finishedHash{hash.New(), hash.New(), nil, nil, buffer, version, prf} - } - - return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), buffer, version, prf} -} - -// A finishedHash calculates the hash of a set of handshake messages suitable -// for including in a Finished message. -type finishedHash struct { - client hash.Hash - server hash.Hash - - // Prior to TLS 1.2, an additional MD5 hash is required. - clientMD5 hash.Hash - serverMD5 hash.Hash - - // In TLS 1.2, a full buffer is sadly required. - buffer []byte - - version uint16 - prf func(result, secret, label, seed []byte) -} - -func (h *finishedHash) Write(msg []byte) (n int, err error) { - h.client.Write(msg) - h.server.Write(msg) - - if h.version < VersionTLS12 { - h.clientMD5.Write(msg) - h.serverMD5.Write(msg) - } - - if h.buffer != nil { - h.buffer = append(h.buffer, msg...) - } - - return len(msg), nil -} - -func (h finishedHash) Sum() []byte { - if h.version >= VersionTLS12 { - return h.client.Sum(nil) - } - - out := make([]byte, 0, md5.Size+sha1.Size) - out = h.clientMD5.Sum(out) - return h.client.Sum(out) -} - -// clientSum returns the contents of the verify_data member of a client's -// Finished message. -func (h finishedHash) clientSum(masterSecret []byte) []byte { - out := make([]byte, finishedVerifyLength) - h.prf(out, masterSecret, clientFinishedLabel, h.Sum()) - return out -} - -// serverSum returns the contents of the verify_data member of a server's -// Finished message. -func (h finishedHash) serverSum(masterSecret []byte) []byte { - out := make([]byte, finishedVerifyLength) - h.prf(out, masterSecret, serverFinishedLabel, h.Sum()) - return out -} - -// hashForClientCertificate returns the handshake messages so far, pre-hashed if -// necessary, suitable for signing by a TLS client certificate. -func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash) []byte { - if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil { - panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer") - } - - if sigType == signatureEd25519 { - return h.buffer - } - - if h.version >= VersionTLS12 { - hash := hashAlg.New() - hash.Write(h.buffer) - return hash.Sum(nil) - } - - if sigType == signatureECDSA { - return h.server.Sum(nil) - } - - return h.Sum() -} - -// discardHandshakeBuffer is called when there is no more need to -// buffer the entirety of the handshake messages. -func (h *finishedHash) discardHandshakeBuffer() { - h.buffer = nil -} - -// noExportedKeyingMaterial is used as a value of -// ConnectionState.ekm when renegotiation is enabled and thus -// we wish to fail all key-material export requests. -func noExportedKeyingMaterial(label string, context []byte, length int) ([]byte, error) { - return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when renegotiation is enabled") -} - -// ekmFromMasterSecret generates exported keying material as defined in RFC 5705. -func ekmFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte) func(string, []byte, int) ([]byte, error) { - return func(label string, context []byte, length int) ([]byte, error) { - switch label { - case "client finished", "server finished", "master secret", "key expansion": - // These values are reserved and may not be used. - return nil, fmt.Errorf("crypto/tls: reserved ExportKeyingMaterial label: %s", label) - } - - seedLen := len(serverRandom) + len(clientRandom) - if context != nil { - seedLen += 2 + len(context) - } - seed := make([]byte, 0, seedLen) - - seed = append(seed, clientRandom...) - seed = append(seed, serverRandom...) - - if context != nil { - if len(context) >= 1<<16 { - return nil, fmt.Errorf("crypto/tls: ExportKeyingMaterial context too long") - } - seed = append(seed, byte(len(context)>>8), byte(len(context))) - seed = append(seed, context...) - } - - keyMaterial := make([]byte, length) - prfForVersion(version, suite)(keyMaterial, masterSecret, []byte(label), seed) - return keyMaterial, nil - } -} diff --git a/pkg/tls/prf_test.go b/pkg/tls/prf_test.go deleted file mode 100644 index 8233985a6..000000000 --- a/pkg/tls/prf_test.go +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "encoding/hex" - "testing" -) - -type testSplitPreMasterSecretTest struct { - in, out1, out2 string -} - -var testSplitPreMasterSecretTests = []testSplitPreMasterSecretTest{ - {"", "", ""}, - {"00", "00", "00"}, - {"0011", "00", "11"}, - {"001122", "0011", "1122"}, - {"00112233", "0011", "2233"}, -} - -func TestSplitPreMasterSecret(t *testing.T) { - for i, test := range testSplitPreMasterSecretTests { - in, _ := hex.DecodeString(test.in) - out1, out2 := splitPreMasterSecret(in) - s1 := hex.EncodeToString(out1) - s2 := hex.EncodeToString(out2) - if s1 != test.out1 || s2 != test.out2 { - t.Errorf("#%d: got: (%s, %s) want: (%s, %s)", i, s1, s2, test.out1, test.out2) - } - } -} - -type testKeysFromTest struct { - version uint16 - suite *cipherSuite - preMasterSecret string - clientRandom, serverRandom string - masterSecret string - clientMAC, serverMAC string - clientKey, serverKey string - macLen, keyLen int - contextKeyingMaterial, noContextKeyingMaterial string -} - -func TestKeysFromPreMasterSecret(t *testing.T) { - for i, test := range testKeysFromTests { - in, _ := hex.DecodeString(test.preMasterSecret) - clientRandom, _ := hex.DecodeString(test.clientRandom) - serverRandom, _ := hex.DecodeString(test.serverRandom) - - masterSecret := masterFromPreMasterSecret(test.version, test.suite, in, clientRandom, serverRandom) - if s := hex.EncodeToString(masterSecret); s != test.masterSecret { - t.Errorf("#%d: bad master secret %s, want %s", i, s, test.masterSecret) - continue - } - - clientMAC, serverMAC, clientKey, serverKey, _, _ := keysFromMasterSecret(test.version, test.suite, masterSecret, clientRandom, serverRandom, test.macLen, test.keyLen, 0) - clientMACString := hex.EncodeToString(clientMAC) - serverMACString := hex.EncodeToString(serverMAC) - clientKeyString := hex.EncodeToString(clientKey) - serverKeyString := hex.EncodeToString(serverKey) - if clientMACString != test.clientMAC || - serverMACString != test.serverMAC || - clientKeyString != test.clientKey || - serverKeyString != test.serverKey { - t.Errorf("#%d: got: (%s, %s, %s, %s) want: (%s, %s, %s, %s)", i, clientMACString, serverMACString, clientKeyString, serverKeyString, test.clientMAC, test.serverMAC, test.clientKey, test.serverKey) - } - - ekm := ekmFromMasterSecret(test.version, test.suite, masterSecret, clientRandom, serverRandom) - contextKeyingMaterial, err := ekm("label", []byte("context"), 32) - if err != nil { - t.Fatalf("ekmFromMasterSecret failed: %v", err) - } - - noContextKeyingMaterial, err := ekm("label", nil, 32) - if err != nil { - t.Fatalf("ekmFromMasterSecret failed: %v", err) - } - - if hex.EncodeToString(contextKeyingMaterial) != test.contextKeyingMaterial || - hex.EncodeToString(noContextKeyingMaterial) != test.noContextKeyingMaterial { - t.Errorf("#%d: got keying material: (%s, %s) want: (%s, %s)", i, contextKeyingMaterial, noContextKeyingMaterial, test.contextKeyingMaterial, test.noContextKeyingMaterial) - } - } -} - -// These test vectors were generated from GnuTLS using `gnutls-cli --insecure -d 9 ` -var testKeysFromTests = []testKeysFromTest{ - { - VersionTLS10, - cipherSuiteByID(TLS_RSA_WITH_RC4_128_SHA), - "0302cac83ad4b1db3b9ab49ad05957de2a504a634a386fc600889321e1a971f57479466830ac3e6f468e87f5385fa0c5", - "4ae66303755184a3917fcb44880605fcc53baa01912b22ed94473fc69cebd558", - "4ae663020ec16e6bb5130be918cfcafd4d765979a3136a5d50c593446e4e44db", - "3d851bab6e5556e959a16bc36d66cfae32f672bfa9ecdef6096cbb1b23472df1da63dbbd9827606413221d149ed08ceb", - "805aaa19b3d2c0a0759a4b6c9959890e08480119", - "2d22f9fe519c075c16448305ceee209fc24ad109", - "d50b5771244f850cd8117a9ccafe2cf1", - "e076e33206b30507a85c32855acd0919", - 20, - 16, - "4d1bb6fc278c37d27aa6e2a13c2e079095d143272c2aa939da33d88c1c0cec22", - "93fba89599b6321ae538e27c6548ceb8b46821864318f5190d64a375e5d69d41", - }, - { - VersionTLS10, - cipherSuiteByID(TLS_RSA_WITH_RC4_128_SHA), - "03023f7527316bc12cbcd69e4b9e8275d62c028f27e65c745cfcddc7ce01bd3570a111378b63848127f1c36e5f9e4890", - "4ae66364b5ea56b20ce4e25555aed2d7e67f42788dd03f3fee4adae0459ab106", - "4ae66363ab815cbf6a248b87d6b556184e945e9b97fbdf247858b0bdafacfa1c", - "7d64be7c80c59b740200b4b9c26d0baaa1c5ae56705acbcf2307fe62beb4728c19392c83f20483801cce022c77645460", - "97742ed60a0554ca13f04f97ee193177b971e3b0", - "37068751700400e03a8477a5c7eec0813ab9e0dc", - "207cddbc600d2a200abac6502053ee5c", - "df3f94f6e1eacc753b815fe16055cd43", - 20, - 16, - "2c9f8961a72b97cbe76553b5f954caf8294fc6360ef995ac1256fe9516d0ce7f", - "274f19c10291d188857ad8878e2119f5aa437d4da556601cf1337aff23154016", - }, - { - VersionTLS10, - cipherSuiteByID(TLS_RSA_WITH_RC4_128_SHA), - "832d515f1d61eebb2be56ba0ef79879efb9b527504abb386fb4310ed5d0e3b1f220d3bb6b455033a2773e6d8bdf951d278a187482b400d45deb88a5d5a6bb7d6a7a1decc04eb9ef0642876cd4a82d374d3b6ff35f0351dc5d411104de431375355addc39bfb1f6329fb163b0bc298d658338930d07d313cd980a7e3d9196cac1", - "4ae663b2ee389c0de147c509d8f18f5052afc4aaf9699efe8cb05ece883d3a5e", - "4ae664d503fd4cff50cfc1fb8fc606580f87b0fcdac9554ba0e01d785bdf278e", - "1aff2e7a2c4279d0126f57a65a77a8d9d0087cf2733366699bec27eb53d5740705a8574bb1acc2abbe90e44f0dd28d6c", - "3c7647c93c1379a31a609542aa44e7f117a70085", - "0d73102994be74a575a3ead8532590ca32a526d4", - "ac7581b0b6c10d85bbd905ffbf36c65e", - "ff07edde49682b45466bd2e39464b306", - 20, - 16, - "678b0d43f607de35241dc7e9d1a7388a52c35033a1a0336d4d740060a6638fe2", - "f3b4ac743f015ef21d79978297a53da3e579ee047133f38c234d829c0f907dab", - }, -} diff --git a/pkg/tls/ticket.go b/pkg/tls/ticket.go deleted file mode 100644 index b82ccd141..000000000 --- a/pkg/tls/ticket.go +++ /dev/null @@ -1,185 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/hmac" - "crypto/sha256" - "crypto/subtle" - "errors" - "io" - - "golang.org/x/crypto/cryptobyte" -) - -// sessionState contains the information that is serialized into a session -// ticket in order to later resume a connection. -type sessionState struct { - vers uint16 - cipherSuite uint16 - createdAt uint64 - masterSecret []byte // opaque master_secret<1..2^16-1>; - // struct { opaque certificate<1..2^24-1> } Certificate; - certificates [][]byte // Certificate certificate_list<0..2^24-1>; - - // usedOldKey is true if the ticket from which this session came from - // was encrypted with an older key and thus should be refreshed. - usedOldKey bool -} - -func (m *sessionState) marshal() ([]byte, error) { - var b cryptobyte.Builder - b.AddUint16(m.vers) - b.AddUint16(m.cipherSuite) - addUint64(&b, m.createdAt) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.masterSecret) - }) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - for _, cert := range m.certificates { - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(cert) - }) - } - }) - return b.Bytes() -} - -func (m *sessionState) unmarshal(data []byte) bool { - *m = sessionState{usedOldKey: m.usedOldKey} - s := cryptobyte.String(data) - if ok := s.ReadUint16(&m.vers) && - s.ReadUint16(&m.cipherSuite) && - readUint64(&s, &m.createdAt) && - readUint16LengthPrefixed(&s, &m.masterSecret) && - len(m.masterSecret) != 0; !ok { - return false - } - var certList cryptobyte.String - if !s.ReadUint24LengthPrefixed(&certList) { - return false - } - for !certList.Empty() { - var cert []byte - if !readUint24LengthPrefixed(&certList, &cert) { - return false - } - m.certificates = append(m.certificates, cert) - } - return s.Empty() -} - -// sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first -// version (revision = 0) doesn't carry any of the information needed for 0-RTT -// validation and the nonce is always empty. -type sessionStateTLS13 struct { - // uint8 version = 0x0304; - // uint8 revision = 0; - cipherSuite uint16 - createdAt uint64 - resumptionSecret []byte // opaque resumption_master_secret<1..2^8-1>; - certificate Certificate // CertificateEntry certificate_list<0..2^24-1>; -} - -func (m *sessionStateTLS13) marshal() ([]byte, error) { - var b cryptobyte.Builder - b.AddUint16(VersionTLS13) - b.AddUint8(0) // revision - b.AddUint16(m.cipherSuite) - addUint64(&b, m.createdAt) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.resumptionSecret) - }) - marshalCertificate(&b, m.certificate) - return b.Bytes() -} - -func (m *sessionStateTLS13) unmarshal(data []byte) bool { - *m = sessionStateTLS13{} - s := cryptobyte.String(data) - var version uint16 - var revision uint8 - return s.ReadUint16(&version) && - version == VersionTLS13 && - s.ReadUint8(&revision) && - revision == 0 && - s.ReadUint16(&m.cipherSuite) && - readUint64(&s, &m.createdAt) && - readUint8LengthPrefixed(&s, &m.resumptionSecret) && - len(m.resumptionSecret) != 0 && - unmarshalCertificate(&s, &m.certificate) && - s.Empty() -} - -func (c *Conn) encryptTicket(state []byte) ([]byte, error) { - if len(c.ticketKeys) == 0 { - return nil, errors.New("tls: internal error: session ticket keys unavailable") - } - - encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(state)+sha256.Size) - keyName := encrypted[:ticketKeyNameLen] - iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] - macBytes := encrypted[len(encrypted)-sha256.Size:] - - if _, err := io.ReadFull(c.config.rand(), iv); err != nil { - return nil, err - } - key := c.ticketKeys[0] - copy(keyName, key.keyName[:]) - block, err := aes.NewCipher(key.aesKey[:]) - if err != nil { - return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error()) - } - cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], state) - - mac := hmac.New(sha256.New, key.hmacKey[:]) - mac.Write(encrypted[:len(encrypted)-sha256.Size]) - mac.Sum(macBytes[:0]) - - return encrypted, nil -} - -func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) { - if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size { - return nil, false - } - - keyName := encrypted[:ticketKeyNameLen] - iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] - macBytes := encrypted[len(encrypted)-sha256.Size:] - ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size] - - keyIndex := -1 - for i, candidateKey := range c.ticketKeys { - if bytes.Equal(keyName, candidateKey.keyName[:]) { - keyIndex = i - break - } - } - if keyIndex == -1 { - return nil, false - } - key := &c.ticketKeys[keyIndex] - - mac := hmac.New(sha256.New, key.hmacKey[:]) - mac.Write(encrypted[:len(encrypted)-sha256.Size]) - expected := mac.Sum(nil) - - if subtle.ConstantTimeCompare(macBytes, expected) != 1 { - return nil, false - } - - block, err := aes.NewCipher(key.aesKey[:]) - if err != nil { - return nil, false - } - plaintext = make([]byte, len(ciphertext)) - cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext) - - return plaintext, keyIndex > 0 -} diff --git a/pkg/tls/tls.go b/pkg/tls/tls.go deleted file mode 100644 index c20aa3ab1..000000000 --- a/pkg/tls/tls.go +++ /dev/null @@ -1,193 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package tls partially implements TLS 1.2, as specified in RFC 5246, -// and TLS 1.3, as specified in RFC 8446. -package tls - -// BUG(agl): The crypto/tls package only implements some countermeasures -// against Lucky13 attacks on CBC-mode encryption, and only on SHA1 -// variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and -// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. - -import ( - "bytes" - "crypto" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "errors" - "fmt" - "net" - "os" - "strings" -) - -// Server returns a new TLS server side connection -// using conn as the underlying transport. -// The configuration config must be non-nil and must include -// at least one certificate or else set GetCertificate. -func Server(conn net.Conn, config *Config) *Conn { - c := &Conn{ - conn: conn, - config: config, - } - c.handshakeFn = c.serverHandshake - return c -} - -// Client returns a new TLS client side connection -// using conn as the underlying transport. -// The config cannot be nil: users must set either ServerName or -// InsecureSkipVerify in the config. -func Client(conn net.Conn, config *Config) *Conn { - c := &Conn{ - conn: conn, - config: config, - isClient: true, - } - c.handshakeFn = c.clientHandshake - return c -} - -type timeoutError struct{} - -func (timeoutError) Error() string { return "tls: DialWithDialer timed out" } -func (timeoutError) Timeout() bool { return true } -func (timeoutError) Temporary() bool { return true } - -// LoadX509KeyPair reads and parses a public/private key pair from a pair -// of files. The files must contain PEM encoded data. The certificate file -// may contain intermediate certificates following the leaf certificate to -// form a certificate chain. On successful return, Certificate.Leaf will -// be nil because the parsed form of the certificate is not retained. -func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) { - certPEMBlock, err := os.ReadFile(certFile) - if err != nil { - return Certificate{}, err - } - keyPEMBlock, err := os.ReadFile(keyFile) - if err != nil { - return Certificate{}, err - } - return X509KeyPair(certPEMBlock, keyPEMBlock) -} - -// X509KeyPair parses a public/private key pair from a pair of -// PEM encoded data. On successful return, Certificate.Leaf will be nil because -// the parsed form of the certificate is not retained. -func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { - fail := func(err error) (Certificate, error) { return Certificate{}, err } - - var cert Certificate - var skippedBlockTypes []string - for { - var certDERBlock *pem.Block - certDERBlock, certPEMBlock = pem.Decode(certPEMBlock) - if certDERBlock == nil { - break - } - if certDERBlock.Type == "CERTIFICATE" { - cert.Certificate = append(cert.Certificate, certDERBlock.Bytes) - } else { - skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type) - } - } - - if len(cert.Certificate) == 0 { - if len(skippedBlockTypes) == 0 { - return fail(errors.New("tls: failed to find any PEM data in certificate input")) - } - if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") { - return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched")) - } - return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) - } - - skippedBlockTypes = skippedBlockTypes[:0] - var keyDERBlock *pem.Block - for { - keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock) - if keyDERBlock == nil { - if len(skippedBlockTypes) == 0 { - return fail(errors.New("tls: failed to find any PEM data in key input")) - } - if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" { - return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key")) - } - return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) - } - if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") { - break - } - skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type) - } - - // We don't need to parse the public key for TLS, but we so do anyway - // to check that it looks sane and matches the private key. - x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) - if err != nil { - return fail(err) - } - - cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes) - if err != nil { - return fail(err) - } - - switch pub := x509Cert.PublicKey.(type) { - case *rsa.PublicKey: - priv, ok := cert.PrivateKey.(*rsa.PrivateKey) - if !ok { - return fail(errors.New("tls: private key type does not match public key type")) - } - if pub.N.Cmp(priv.N) != 0 { - return fail(errors.New("tls: private key does not match public key")) - } - case *ecdsa.PublicKey: - priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey) - if !ok { - return fail(errors.New("tls: private key type does not match public key type")) - } - if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { - return fail(errors.New("tls: private key does not match public key")) - } - case ed25519.PublicKey: - priv, ok := cert.PrivateKey.(ed25519.PrivateKey) - if !ok { - return fail(errors.New("tls: private key type does not match public key type")) - } - if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) { - return fail(errors.New("tls: private key does not match public key")) - } - default: - return fail(errors.New("tls: unknown public key algorithm")) - } - - return cert, nil -} - -// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates -// PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys. -// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three. -func parsePrivateKey(der []byte) (crypto.PrivateKey, error) { - if key, err := x509.ParsePKCS1PrivateKey(der); err == nil { - return key, nil - } - if key, err := x509.ParsePKCS8PrivateKey(der); err == nil { - switch key := key.(type) { - case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey: - return key, nil - default: - return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping") - } - } - if key, err := x509.ParseECPrivateKey(der); err == nil { - return key, nil - } - - return nil, errors.New("tls: failed to parse private key") -} diff --git a/pkg/tls/tls_test.go b/pkg/tls/tls_test.go deleted file mode 100644 index 0eb33fc77..000000000 --- a/pkg/tls/tls_test.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "strings" -) - -var rsaCertPEM = `-----BEGIN CERTIFICATE----- -MIIB0zCCAX2gAwIBAgIJAI/M7BYjwB+uMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV -BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX -aWRnaXRzIFB0eSBMdGQwHhcNMTIwOTEyMjE1MjAyWhcNMTUwOTEyMjE1MjAyWjBF -MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50 -ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANLJ -hPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wok/4xIA+ui35/MmNa -rtNuC+BdZ1tMuVCPFZcCAwEAAaNQME4wHQYDVR0OBBYEFJvKs8RfJaXTH08W+SGv -zQyKn0H8MB8GA1UdIwQYMBaAFJvKs8RfJaXTH08W+SGvzQyKn0H8MAwGA1UdEwQF -MAMBAf8wDQYJKoZIhvcNAQEFBQADQQBJlffJHybjDGxRMqaRmDhX0+6v02TUKZsW -r5QuVbpQhH6u+0UgcW0jp9QwpxoPTLTWGXEWBBBurxFwiCBhkQ+V ------END CERTIFICATE----- -` - -func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } From d4ab0722fc25dbe6e4daca739cc420ab0a81cd16 Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Sat, 1 Apr 2023 07:45:39 +0000 Subject: [PATCH 26/34] Fix golangci-lint --- connection.go | 8 ++++---- eventloop.go | 10 +++++----- pkg/tls/go120.go | 3 ++- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/connection.go b/connection.go index 84ad64041..eb026a667 100644 --- a/connection.go +++ b/connection.go @@ -526,18 +526,18 @@ func (c *conn) UpgradeTLS(config *tls.Config) (err error) { // TODO: create a sync.pool to manage the TLS connection c.tlsconn = tls.Server(c, config.Clone()) - //很有可能握手包在UpgradeTls之前发过来了,这里把inboundBuffer剩余数据当做握手数据处理 + // 很有可能握手包在UpgradeTls之前发过来了,这里把inboundBuffer剩余数据当做握手数据处理 if c.inboundBuffer.Len() > 0 { head, tail := c.inboundBuffer.Peek(-1) - c.tlsconn.RawInputSet(head) - c.tlsconn.RawInputSet(tail) + c.tlsconn.RawInputSet(head) //nolint:errcheck + c.tlsconn.RawInputSet(tail) //nolint:errcheck c.inboundBuffer.Reset() if err := c.tlsconn.Handshake(); err != nil { return err } } - //握手失败的关了 + // 握手失败的关了 time.AfterFunc(time.Second*5, func() { if c.opened && (c.tlsconn == nil || !c.tlsconn.HandshakeComplete()) { c.Close() diff --git a/eventloop.go b/eventloop.go index 41ad23ead..4c7f7da69 100644 --- a/eventloop.go +++ b/eventloop.go @@ -160,8 +160,8 @@ func (el *eventloop) read(c *conn) error { if c.tlsconn != nil && c.tlsconn.IsKTLSRXEnabled() { // attach the gnet eventloop.buffer to tlsconn.rawInput. // So, KTLS can decrypt the data directly to the buffer without memory allocation. - // Since data is read through KTLS, there is no need to call unix.read(c.fd, el.buffer) - c.tlsconn.RawInputSet(el.buffer) + // Since data is read through KTLS, there is no need to call unix.read(c.fd, el.buffer) + c.tlsconn.RawInputSet(el.buffer) //nolint:errcheck return el.readTLS(c) } @@ -178,9 +178,9 @@ func (el *eventloop) read(c *conn) error { if c.tlsconn != nil { // attach the gnet eventloop.buffer to tlsconn.rawInput. - c.tlsconn.RawInputSet(el.buffer[:n]) + c.tlsconn.RawInputSet(el.buffer[:n]) //nolint:errcheck if !c.tlsconn.HandshakeComplete() { - //先判断是否足够一条消息 + // 先判断是否足够一条消息 data := c.tlsconn.RawInputData() if !c.tlsconn.IsRecordCompleted(data) { c.tlsconn.DataDone() @@ -191,7 +191,7 @@ func (el *eventloop) read(c *conn) error { // so no need to call tlsconn.DataDone() return el.closeConn(c, os.NewSyscallError("TLS handshake", err)) } - if !c.tlsconn.HandshakeComplete() || len(c.tlsconn.RawInputData()) == 0 { //握手没成功,或者握手成功,但是没有数据黏包了 + if !c.tlsconn.HandshakeComplete() || len(c.tlsconn.RawInputData()) == 0 { // 握手没成功,或者握手成功,但是没有数据黏包了 c.tlsconn.DataDone() return nil } diff --git a/pkg/tls/go120.go b/pkg/tls/go120.go index e01e88aa3..5b1017877 100644 --- a/pkg/tls/go120.go +++ b/pkg/tls/go120.go @@ -9,6 +9,7 @@ import ( gtls "github.com/0-haha/gnet_go_tls/v120" ) +//nolint:revive const ( // TLS 1.0 - 1.2 cipher suites. TLS_RSA_WITH_RC4_128_SHA uint16 = gtls.TLS_RSA_WITH_RC4_128_SHA @@ -84,4 +85,4 @@ func Server(conn net.Conn, config *Config) *Conn func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) //go:linkname X509KeyPair github.com/0-haha/gnet_go_tls/v120.X509KeyPair -func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) \ No newline at end of file +func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) From a0bf9d9e185c0799b8037f9e93cdb3a085f12c4e Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Sat, 1 Apr 2023 22:57:14 +0000 Subject: [PATCH 27/34] bug: clean up the inner buffer after read TLS event (the same reason as #445) --- eventloop.go | 1 + 1 file changed, 1 insertion(+) diff --git a/eventloop.go b/eventloop.go index 4c7f7da69..0dbfa4e35 100644 --- a/eventloop.go +++ b/eventloop.go @@ -144,6 +144,7 @@ func (el *eventloop) readTLS(c *conn) error { return gerrors.ErrEngineShutdown } _, _ = c.inboundBuffer.Write(c.buffer) + c.buffer = c.buffer[:0] // all available TLS records are processed if !c.tlsconn.IsRecordCompleted(c.tlsconn.RawInputData()) { From 369338e922f1a25dfa8cc4d51c1040fbf68f5b92 Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Sat, 1 Apr 2023 23:01:16 +0000 Subject: [PATCH 28/34] fix: kernel TLS 1.3 RX not supported on kernel <6 by bumpering gnet_go_tls to v120.2.0 For details, see https://github.com/0-haha/gnet_go_tls/commit/5728fd829790624e21452e217702b9c9964ded45 --- go.mod | 4 ++-- go.sum | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/go.mod b/go.mod index b048f0d46..9b4c794e6 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,11 @@ module github.com/panjf2000/gnet/v2 require ( - github.com/0-haha/gnet_go_tls/v120 v120.0.1 + github.com/0-haha/gnet_go_tls/v120 v120.2.0 github.com/panjf2000/ants/v2 v2.7.1 github.com/stretchr/testify v1.8.2 github.com/valyala/bytebufferpool v1.0.0 go.uber.org/zap v1.21.0 - golang.org/x/crypto v0.5.0 golang.org/x/sys v0.4.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0 ) @@ -17,6 +16,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect go.uber.org/atomic v1.10.0 // indirect go.uber.org/multierr v1.8.0 // indirect + golang.org/x/crypto v0.5.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index fb5427c7e..73c94476a 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/0-haha/gnet_go_tls/v120 v120.0.1 h1:oFVjzpqQO4k3MfbU211oFSx8sRSALXto7T5u2bbYqTI= -github.com/0-haha/gnet_go_tls/v120 v120.0.1/go.mod h1:ZDwYfvBBzRwvZNENOXeVI+QjVDKr5r8aEKLMDSchRkI= +github.com/0-haha/gnet_go_tls/v120 v120.2.0 h1:uPq6/VoULSna7DHMR3KELdALSUaNxZ0tjFts9Yneiso= +github.com/0-haha/gnet_go_tls/v120 v120.2.0/go.mod h1:ZDwYfvBBzRwvZNENOXeVI+QjVDKr5r8aEKLMDSchRkI= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= From 0ccefca0e215e8f0448df77a652338412b53f56e Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Wed, 5 Apr 2023 03:41:28 +0000 Subject: [PATCH 29/34] fix: make comments to english --- connection.go | 5 +++-- eventloop.go | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/connection.go b/connection.go index eb026a667..95a3db0eb 100644 --- a/connection.go +++ b/connection.go @@ -526,7 +526,8 @@ func (c *conn) UpgradeTLS(config *tls.Config) (err error) { // TODO: create a sync.pool to manage the TLS connection c.tlsconn = tls.Server(c, config.Clone()) - // 很有可能握手包在UpgradeTls之前发过来了,这里把inboundBuffer剩余数据当做握手数据处理 + // It is very likely that the handshake packet was sent before UpgradeTls. + // So, the remaining data in the inboundBuffer is treated as handshake data here if c.inboundBuffer.Len() > 0 { head, tail := c.inboundBuffer.Peek(-1) c.tlsconn.RawInputSet(head) //nolint:errcheck @@ -537,7 +538,7 @@ func (c *conn) UpgradeTLS(config *tls.Config) (err error) { } } - // 握手失败的关了 + // handshake is failed time.AfterFunc(time.Second*5, func() { if c.opened && (c.tlsconn == nil || !c.tlsconn.HandshakeComplete()) { c.Close() diff --git a/eventloop.go b/eventloop.go index 0dbfa4e35..e12a8a794 100644 --- a/eventloop.go +++ b/eventloop.go @@ -181,7 +181,7 @@ func (el *eventloop) read(c *conn) error { // attach the gnet eventloop.buffer to tlsconn.rawInput. c.tlsconn.RawInputSet(el.buffer[:n]) //nolint:errcheck if !c.tlsconn.HandshakeComplete() { - // 先判断是否足够一条消息 + // check whether the buffer data is sufficient for a complete TLS record data := c.tlsconn.RawInputData() if !c.tlsconn.IsRecordCompleted(data) { c.tlsconn.DataDone() From 37393e2bde8d6444904dfa6faa06af74f08e7a1b Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Wed, 5 Apr 2023 03:46:17 +0000 Subject: [PATCH 30/34] fix: typos in comments --- connection.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/connection.go b/connection.go index 95a3db0eb..a08c58f7b 100644 --- a/connection.go +++ b/connection.go @@ -132,7 +132,7 @@ func (c *conn) open(buf []byte) error { func (c *conn) writeTLS(data []byte) (n int, err error) { // use tls to encrypt the data before sending it. // tlsconn will call gnet.WriteTCP() to sent the data directly. - // If gnetConn.outboundBufferis not empty, data will be + // If gnetConn.outboundBuffer is not empty, data will be // buffered in gnetConn.outboundBuffer. n, err = c.tlsconn.Write(data) return @@ -173,7 +173,7 @@ func (c *conn) writevTLS(bs [][]byte) (n int, err error) { // use tls to encrypt the data before sending it. // tlsconn will call gnet.WriteTCP() to sent the data directly. - // If gnetConn.outboundBufferis not empty, data will be + // If gnetConn.outboundBuffer is not empty, data will be // buffered in gnetConn.outboundBuffer. sent := 0 var sentN int From 2705b62025e8f64d2b251bcaaaf466bd95e5816d Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Wed, 5 Apr 2023 16:25:25 +0000 Subject: [PATCH 31/34] change package name gnet_go_tls/v120 to gnet-tls-go1-20 --- go.mod | 6 +++--- go.sum | 12 ++++++------ pkg/tls/go120.go | 8 ++++---- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/go.mod b/go.mod index 9b4c794e6..5abf79eaf 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,12 @@ module github.com/panjf2000/gnet/v2 require ( - github.com/0-haha/gnet_go_tls/v120 v120.2.0 + github.com/0-haha/gnet-tls-go1-20 v0.2.0-rc.1 github.com/panjf2000/ants/v2 v2.7.1 github.com/stretchr/testify v1.8.2 github.com/valyala/bytebufferpool v1.0.0 go.uber.org/zap v1.21.0 - golang.org/x/sys v0.4.0 + golang.org/x/sys v0.7.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0 ) @@ -16,7 +16,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect go.uber.org/atomic v1.10.0 // indirect go.uber.org/multierr v1.8.0 // indirect - golang.org/x/crypto v0.5.0 // indirect + golang.org/x/crypto v0.7.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 73c94476a..bb58417b2 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/0-haha/gnet_go_tls/v120 v120.2.0 h1:uPq6/VoULSna7DHMR3KELdALSUaNxZ0tjFts9Yneiso= -github.com/0-haha/gnet_go_tls/v120 v120.2.0/go.mod h1:ZDwYfvBBzRwvZNENOXeVI+QjVDKr5r8aEKLMDSchRkI= +github.com/0-haha/gnet-tls-go1-20 v0.2.0-rc.1 h1:7OY74XMmmtQuGCaIkXJy8IEemIzhEX88esFmwKAagkI= +github.com/0-haha/gnet-tls-go1-20 v0.2.0-rc.1/go.mod h1:ZipvhkWAVMtaMixCaqOh9NgQ+u4u1OCzDOohXR33VfU= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= @@ -43,8 +43,8 @@ go.uber.org/zap v1.21.0 h1:WefMeulhovoZ2sYXz7st6K0sLj7bBhpiFaud4r4zST8= go.uber.org/zap v1.21.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= -golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= +golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= +golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -60,8 +60,8 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= -golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= +golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= diff --git a/pkg/tls/go120.go b/pkg/tls/go120.go index 5b1017877..cd5a692de 100644 --- a/pkg/tls/go120.go +++ b/pkg/tls/go120.go @@ -6,7 +6,7 @@ import ( "net" _ "unsafe" - gtls "github.com/0-haha/gnet_go_tls/v120" + gtls "github.com/0-haha/gnet-tls-go1-20" ) //nolint:revive @@ -78,11 +78,11 @@ const ( VersionSSL30 = gtls.CurveP256 ) -//go:linkname Server github.com/0-haha/gnet_go_tls/v120.Server +//go:linkname Server github.com/0-haha/gnet-tls-go1-20.Server func Server(conn net.Conn, config *Config) *Conn -//go:linkname LoadX509KeyPair github.com/0-haha/gnet_go_tls/v120.LoadX509KeyPair +//go:linkname LoadX509KeyPair github.com/0-haha/gnet-tls-go1-20.LoadX509KeyPair func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) -//go:linkname X509KeyPair github.com/0-haha/gnet_go_tls/v120.X509KeyPair +//go:linkname X509KeyPair github.com/0-haha/gnet-tls-go1-20.X509KeyPair func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) From bef64fa4fd21454cf3c5f698d3f121cbdce6388e Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Wed, 5 Apr 2023 16:32:16 +0000 Subject: [PATCH 32/34] change gnet-tls-go1-20 version v1.20.2-rc.1 --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 5abf79eaf..46368c878 100644 --- a/go.mod +++ b/go.mod @@ -1,7 +1,7 @@ module github.com/panjf2000/gnet/v2 require ( - github.com/0-haha/gnet-tls-go1-20 v0.2.0-rc.1 + github.com/0-haha/gnet-tls-go1-20 v1.20.2-rc.1 github.com/panjf2000/ants/v2 v2.7.1 github.com/stretchr/testify v1.8.2 github.com/valyala/bytebufferpool v1.0.0 diff --git a/go.sum b/go.sum index bb58417b2..00659992c 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/0-haha/gnet-tls-go1-20 v0.2.0-rc.1 h1:7OY74XMmmtQuGCaIkXJy8IEemIzhEX88esFmwKAagkI= -github.com/0-haha/gnet-tls-go1-20 v0.2.0-rc.1/go.mod h1:ZipvhkWAVMtaMixCaqOh9NgQ+u4u1OCzDOohXR33VfU= +github.com/0-haha/gnet-tls-go1-20 v1.20.2-rc.1 h1:cYbILYopJYpNQej70cMiYUi2vOE+ZEflt0hypVa3gOw= +github.com/0-haha/gnet-tls-go1-20 v1.20.2-rc.1/go.mod h1:ZipvhkWAVMtaMixCaqOh9NgQ+u4u1OCzDOohXR33VfU= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= From ccc7c2831805b96ae9dacc33f89a17c9d0d39834 Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Sun, 21 May 2023 17:27:34 +0000 Subject: [PATCH 33/34] Fix bugs caused by merging conflicts --- acceptor_unix.go | 6 ++---- eventloop_unix.go | 8 ++++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/acceptor_unix.go b/acceptor_unix.go index 35420d1d1..ad2f57075 100644 --- a/acceptor_unix.go +++ b/acceptor_unix.go @@ -100,16 +100,14 @@ func (el *eventloop) accept(fd int, ev netpoll.IOEvent) error { } c := newTCPConn(nfd, el, sa, el.ln.addr, remoteAddr) -<<<<<<< HEAD:acceptor.go + if el.engine.opts.TLSconfig != nil { if err = c.UpgradeTLS(el.engine.opts.TLSconfig); err != nil { return err } } - if err = el.poller.AddRead(c.pollAttachment); err != nil { -======= + if err = el.poller.AddRead(&c.pollAttachment); err != nil { ->>>>>>> f80734af8ec21935798edbfc73362cd9dae73f2b:acceptor_unix.go return err } el.connections.addConn(c, el.idx) diff --git a/eventloop_unix.go b/eventloop_unix.go index f2cbc9e23..0c9795d99 100644 --- a/eventloop_unix.go +++ b/eventloop_unix.go @@ -109,7 +109,7 @@ func (el *eventloop) readTLS(c *conn) error { } // If err is io.EOF, it can either the data is drained, // receives a close notify from the client. - return el.closeConn(c, os.NewSyscallError("TLS read", err)) + return el.close(c, os.NewSyscallError("TLS read", err)) } // load all decrypted data and make it ready for gnet to use @@ -120,7 +120,7 @@ func (el *eventloop) readTLS(c *conn) error { case None: case Close: // tls data will be cleaned up in el.closeConn() - return el.closeConn(c, nil) + return el.close(c, nil) case Shutdown: c.tlsconn.DataDone() return gerrors.ErrEngineShutdown @@ -170,9 +170,9 @@ func (el *eventloop) read(c *conn) error { return nil } if err = c.tlsconn.Handshake(); err != nil { - // closeConn will cleanup the TLS data at the end, + // close will cleanup the TLS data at the end, // so no need to call tlsconn.DataDone() - return el.closeConn(c, os.NewSyscallError("TLS handshake", err)) + return el.close(c, os.NewSyscallError("TLS handshake", err)) } if !c.tlsconn.HandshakeComplete() || len(c.tlsconn.RawInputData()) == 0 { // 握手没成功,或者握手成功,但是没有数据黏包了 c.tlsconn.DataDone() From 18c311d82f96b6d229cad62e78d4e352c778fc81 Mon Sep 17 00:00:00 2001 From: 0-haha <65914167+0-haha@users.noreply.github.com> Date: Tue, 3 Oct 2023 02:29:42 +0000 Subject: [PATCH 34/34] fix the typo --- eventloop_unix.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eventloop_unix.go b/eventloop_unix.go index f9f7c4a59..5b3e798af 100644 --- a/eventloop_unix.go +++ b/eventloop_unix.go @@ -123,7 +123,7 @@ func (el *eventloop) readTLS(c *conn) error { return el.close(c, nil) case Shutdown: c.tlsconn.DataDone() - return gerrors.ErrEngineShutdown + return errorx.ErrEngineShutdown } _, _ = c.inboundBuffer.Write(c.buffer) c.buffer = c.buffer[:0]