diff --git a/ctx.go b/ctx.go index 5437f88..ee1505e 100644 --- a/ctx.go +++ b/ctx.go @@ -76,11 +76,12 @@ func newCtx(method *C.SSL_METHOD) (*Ctx, error) { type SSLVersion int const ( - SSLv3 SSLVersion = 0x02 // Vulnerable to "POODLE" attack. - TLSv1 SSLVersion = 0x03 - TLSv1_1 SSLVersion = 0x04 - TLSv1_2 SSLVersion = 0x05 - NTLS SSLVersion = 0x06 + SSLv3 SSLVersion = 0x0300 // Vulnerable to "POODLE" attack. + TLSv1 SSLVersion = 0x0301 + TLSv1_1 SSLVersion = 0x0302 + TLSv1_2 SSLVersion = 0x0303 + TLSv1_3 SSLVersion = 0x0304 + NTLS SSLVersion = 0x0101 // AnyVersion Make sure to disable SSLv2 and SSLv3 if you use this. SSLv3 is vulnerable // to the "POODLE" attack, and SSLv2 is what, just don't even. @@ -92,20 +93,11 @@ const ( func NewCtxWithVersion(version SSLVersion) (*Ctx, error) { var enableNTLS bool var method *C.SSL_METHOD - switch version { - case SSLv3: - method = C.X_SSLv3_method() - case TLSv1: - method = C.X_TLSv1_method() - case TLSv1_1: - method = C.X_TLSv1_1_method() - case TLSv1_2: - method = C.X_TLSv1_2_method() - case NTLS: + if version == NTLS { method = C.X_NTLS_method() enableNTLS = true - case AnyVersion: - method = C.X_SSLv23_method() + } else { + method = C.TLS_method() } if method == nil { return nil, errors.New("unknown ssl/tls version") @@ -118,6 +110,12 @@ func NewCtxWithVersion(version SSLVersion) (*Ctx, error) { if enableNTLS { C.X_SSL_CTX_enable_ntls(c.ctx) + } else if version == AnyVersion { + C.X_SSL_CTX_set_min_proto_version(c.ctx, C.int(TLSv1)) + C.X_SSL_CTX_set_max_proto_version(c.ctx, C.int(TLSv1_3)) + } else { + C.X_SSL_CTX_set_min_proto_version(c.ctx, C.int(version)) + C.X_SSL_CTX_set_max_proto_version(c.ctx, C.int(version)) } return c, nil @@ -646,14 +644,29 @@ func (c *Ctx) SetSessionId(session_id []byte) error { func (c *Ctx) SetCipherList(list string) error { runtime.LockOSThread() defer runtime.UnlockOSThread() + clist := C.CString(list) defer C.free(unsafe.Pointer(clist)) + if int(C.SSL_CTX_set_cipher_list(c.ctx, clist)) == 0 { return crypto.ErrorFromErrorQueue() } return nil } +func (c *Ctx) SetCipherSuites(suites string) error { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + csuits := C.CString(suites) + defer C.free(unsafe.Pointer(csuits)) + + if int(C.SSL_CTX_set_ciphersuites(c.ctx, csuits)) == 0 { + return crypto.ErrorFromErrorQueue() + } + return nil +} + type SessionCacheModes int const ( diff --git a/examples/tlcp_client/main.go b/examples/tlcp_client/main.go index a665ecc..169829e 100644 --- a/examples/tlcp_client/main.go +++ b/examples/tlcp_client/main.go @@ -18,6 +18,8 @@ import ( "github.com/tongsuo-project/tongsuo-go-sdk/crypto" ) +var cipherSuites = "" + func main() { cipherSuite := "" signCertFile := "" @@ -28,6 +30,7 @@ func main() { connAddr := "" serverName := "" alpnProtocols := []string{"h2", "http/1.1"} + tlsVersion := "" flag.StringVar(&connAddr, "conn", "127.0.0.1:4438", "host:port") flag.StringVar(&cipherSuite, "cipher", "ECC-SM2-SM4-CBC-SM3", "cipher suite") @@ -38,92 +41,113 @@ func main() { flag.StringVar(&caFile, "CAfile", "test/certs/sm2/chain-ca.crt", "CA certificate file") flag.StringVar(&serverName, "servername", "", "server name") flag.Var((*stringSlice)(&alpnProtocols), "alpn", "ALPN protocols") - + flag.StringVar(&tlsVersion, "version", "NTLS", "TLS version") + flag.StringVar(&cipherSuites, "ciphersuites", "ECC-SM2-SM4-CBC-SM3", "cipherSuites") flag.Parse() - ctx, err := ts.NewCtxWithVersion(ts.NTLS) + var version ts.SSLVersion + switch tlsVersion { + case "TLSv1.3": + version = ts.TLSv1_3 + case "TLSv1.2": + version = ts.TLSv1_2 + case "TLSv1.1": + version = ts.TLSv1_1 + case "TLSv1": + version = ts.TLSv1 + case "NTLS": + version = ts.NTLS + default: + version = ts.NTLS + } + ctx, err := ts.NewCtxWithVersion(version) if err != nil { - panic(err) + panic("NewCtxWithVersion failed: " + err.Error()) } if err := ctx.SetClientALPNProtos(alpnProtocols); err != nil { panic(err) } - if err := ctx.SetCipherList(cipherSuite); err != nil { - panic(err) - } - - if signCertFile != "" { - signCertPEM, err := os.ReadFile(signCertFile) - if err != nil { + if version >= ts.TLSv1_3 { + if err := ctx.SetCipherSuites(cipherSuites); err != nil { panic(err) } - signCert, err := crypto.LoadCertificateFromPEM(signCertPEM) - if err != nil { + } else { + if err := ctx.SetCipherList(cipherSuites); err != nil { panic(err) } - - if err := ctx.UseSignCertificate(signCert); err != nil { - panic(err) + if signCertFile != "" { + signCertPEM, err := os.ReadFile(signCertFile) + if err != nil { + panic(err) + } + signCert, err := crypto.LoadCertificateFromPEM(signCertPEM) + if err != nil { + panic(err) + } + + if err := ctx.UseSignCertificate(signCert); err != nil { + panic(err) + } } - } - if signKeyFile != "" { - signKeyPEM, err := os.ReadFile(signKeyFile) - if err != nil { - panic(err) - } - signKey, err := crypto.LoadPrivateKeyFromPEM(signKeyPEM) - if err != nil { - panic(err) + if signKeyFile != "" { + signKeyPEM, err := os.ReadFile(signKeyFile) + if err != nil { + panic(err) + } + signKey, err := crypto.LoadPrivateKeyFromPEM(signKeyPEM) + if err != nil { + panic(err) + } + + if err := ctx.UseSignPrivateKey(signKey); err != nil { + panic(err) + } } - if err := ctx.UseSignPrivateKey(signKey); err != nil { - panic(err) + if encCertFile != "" { + encCertPEM, err := os.ReadFile(encCertFile) + if err != nil { + panic(err) + } + encCert, err := crypto.LoadCertificateFromPEM(encCertPEM) + if err != nil { + panic(err) + } + + if err := ctx.UseEncryptCertificate(encCert); err != nil { + panic(err) + } } - } - if encCertFile != "" { - encCertPEM, err := os.ReadFile(encCertFile) - if err != nil { - panic(err) - } - encCert, err := crypto.LoadCertificateFromPEM(encCertPEM) - if err != nil { - panic(err) - } - - if err := ctx.UseEncryptCertificate(encCert); err != nil { - panic(err) - } - } - - if encKeyFile != "" { - encKeyPEM, err := os.ReadFile(encKeyFile) - if err != nil { - panic(err) - } + if encKeyFile != "" { + encKeyPEM, err := os.ReadFile(encKeyFile) + if err != nil { + panic(err) + } - encKey, err := crypto.LoadPrivateKeyFromPEM(encKeyPEM) - if err != nil { - panic(err) - } + encKey, err := crypto.LoadPrivateKeyFromPEM(encKeyPEM) + if err != nil { + panic(err) + } - if err := ctx.UseEncryptPrivateKey(encKey); err != nil { - panic(err) + if err := ctx.UseEncryptPrivateKey(encKey); err != nil { + panic(err) + } } - } - if caFile != "" { - if err := ctx.LoadVerifyLocations(caFile, ""); err != nil { - panic(err) + if caFile != "" { + if err := ctx.LoadVerifyLocations(caFile, ""); err != nil { + panic(err) + } } } conn, err := ts.Dial("tcp", connAddr, ctx, ts.InsecureSkipHostVerification, serverName) if err != nil { - panic(err) + panic("connected failed" + err.Error()) } defer conn.Close() diff --git a/examples/tlcp_server/main.go b/examples/tlcp_server/main.go index c0c082c..197e875 100644 --- a/examples/tlcp_server/main.go +++ b/examples/tlcp_server/main.go @@ -11,13 +11,20 @@ import ( "bufio" "flag" "fmt" - ts "github.com/tongsuo-project/tongsuo-go-sdk" - "github.com/tongsuo-project/tongsuo-go-sdk/crypto" "log" "net" "os" "path/filepath" "strings" + + ts "github.com/tongsuo-project/tongsuo-go-sdk" + "github.com/tongsuo-project/tongsuo-go-sdk/crypto" +) + +var ( + cipherSuites = "" + cert = "" + key = "" ) func ReadCertificateFiles(dirPath string) (map[string]crypto.GMDoubleCertKey, error) { @@ -89,14 +96,78 @@ func handleConn(conn net.Conn) { log.Println("Close connection") } -func newNTLSServer(acceptAddr string, certKeyPairs map[string]crypto.GMDoubleCertKey, cafile string, alpnProtocols []string) (net.Listener, error) { - - ctx, err := ts.NewCtxWithVersion(ts.NTLS) +func newTLSServer(acceptAddr string, certKeyPairs map[string]crypto.GMDoubleCertKey, cafile string, alpnProtocols []string, tlsVersion string) (net.Listener, error) { + var version ts.SSLVersion + switch tlsVersion { + case "TLSv1.3": + version = ts.TLSv1_3 + case "TLSv1.2": + version = ts.TLSv1_2 + case "TLSv1.1": + version = ts.TLSv1_1 + case "TLSv1": + version = ts.TLSv1 + case "NTLS": + version = ts.NTLS + default: + version = ts.TLSv1_3 + } + ctx, err := ts.NewCtxWithVersion(version) if err != nil { log.Println(err) return nil, err } + if version >= ts.TLSv1_3 { + if err := ctx.SetCipherSuites(cipherSuites); err != nil { + return nil, err + } + // Load a default certificate and key for TLSv1.3 + certPEM, err := os.ReadFile(filepath.Join(cert)) + if err != nil { + log.Println(err) + return nil, err + } + + cert, err := crypto.LoadCertificateFromPEM(certPEM) + if err != nil { + log.Println(err) + return nil, err + } + + if err := ctx.UseCertificate(cert); err != nil { + log.Println(err) + return nil, err + } + + keyPEM, err := os.ReadFile(filepath.Join(key)) + if err != nil { + log.Println(err) + return nil, err + } + + key, err := crypto.LoadPrivateKeyFromPEM(keyPEM) + if err != nil { + log.Println(err) + return nil, err + } + + if err := ctx.UsePrivateKey(key); err != nil { + log.Println(err) + return nil, err + } + } else { + if err := ctx.SetCipherList(cipherSuites); err != nil { + return nil, err + } + // Load a default certificate and key + defaultCertKeyPair := certKeyPairs["default"] + if err := loadCertAndKey(ctx, defaultCertKeyPair); err != nil { + log.Println(err) + return nil, err + } + } + if err := ctx.LoadVerifyLocations(cafile, ""); err != nil { log.Println(err) return nil, err @@ -123,13 +194,6 @@ func newNTLSServer(acceptAddr string, certKeyPairs map[string]crypto.GMDoubleCer return ts.SSLTLSExtErrOK }) - // Load a default certificate and key - defaultCertKeyPair := certKeyPairs["default"] - if err := loadCertAndKey(ctx, defaultCertKeyPair); err != nil { - log.Println(err) - return nil, err - } - // Listen for incoming connections lis, err := ts.Listen("tcp", acceptAddr, ctx) if err != nil { @@ -286,6 +350,7 @@ func main() { caFile := "" acceptAddr := "" alpnProtocols := []string{"h2", "http/1.1"} + tlsVersion := "" flag.StringVar(&acceptAddr, "accept", "127.0.0.1:4438", "host:port") flag.StringVar(&signCertFile, "sign_cert", "test/certs/sm2/server_sign.crt", "sign certificate file") @@ -294,7 +359,10 @@ func main() { flag.StringVar(&encKeyFile, "enc_key", "test/certs/sm2/server_enc.key", "encrypt private key file") flag.StringVar(&caFile, "CAfile", "test/certs/sm2/chain-ca.crt", "CA certificate file") flag.Var((*stringSlice)(&alpnProtocols), "alpn", "ALPN protocols") - + flag.StringVar(&tlsVersion, "version", "NTLS", "TLS version") + flag.StringVar(&cipherSuites, "ciphersuites", "ECC-SM2-SM4-CBC-SM3", "cipherSuites") + flag.StringVar(&cert, "cert", "test/certs/sm2-cert.pem", "certificate file") + flag.StringVar(&key, "key", "test/certs/sm2.key", "private key file") flag.Parse() certFiles, err := ReadCertificateFiles("test/sni_certs") @@ -303,7 +371,7 @@ func main() { return } - server, err := newNTLSServer(acceptAddr, certFiles, caFile, alpnProtocols) + server, err := newTLSServer(acceptAddr, certFiles, caFile, alpnProtocols, tlsVersion) if err != nil { return } diff --git a/ntls_test.go b/ntls_test.go index 961867b..407f269 100644 --- a/ntls_test.go +++ b/ntls_test.go @@ -4,7 +4,6 @@ import ( "bufio" "bytes" "fmt" - "github.com/tongsuo-project/tongsuo-go-sdk/crypto" "log" "math/big" "net" @@ -12,11 +11,15 @@ import ( "path/filepath" "testing" "time" + + "github.com/tongsuo-project/tongsuo-go-sdk/crypto" ) const ( ECCSM2Cipher = "ECC-SM2-WITH-SM4-SM3" ECDHESM2Cipher = "ECDHE-SM2-WITH-SM4-SM3" + TLSSMGCMCipher = "TLS_SM4_GCM_SM3" + TLSSMCCMCipher = "TLS_SM4_CCM_SM3" internalServer = true enableSNI = true @@ -1323,3 +1326,245 @@ func newNTLSServerWithSessionReuse(t *testing.T, testDir string, cacheMode Sessi return &echoServer{lis}, nil } + +func TestTLS13Connection(t *testing.T) { + // Run server + server, err := newTLS13Server(t, "test/certs") + if err != nil { + t.Error(err) + return + } + + defer server.Close() + go server.Run() + + // Run client + connAddr := "127.0.0.1:4433" + + ctx, err := NewCtxWithVersion(TLSv1_3) + if err != nil { + t.Error(err) + return + } + + conn, err := Dial("tcp", connAddr, ctx, InsecureSkipHostVerification, "") + if err != nil { + t.Log(err) + return + } + + defer conn.Close() + + // Check the tls version + tlsVersion, err := conn.GetVersion() + if err != nil { + t.Error(err) + return + } + + if tlsVersion != "TLSv1.3" { + t.Error("tls version is not TLSv1.3") + return + } + + t.Log("tls version", tlsVersion) + + request := "hello tongsuo\n" + if _, err := conn.Write([]byte(request)); err != nil { + t.Error(err) + return + } + + resp, err := bufio.NewReader(conn).ReadString('\n') + if err != nil { + t.Error(err) + return + } + + if resp != request { + t.Error("response data is not expected: ", resp) + return + } +} + +func newTLS13Server(t *testing.T, testDir string, options ...func(sslctx *Ctx) error) (*echoServer, error) { + ctx, err := NewCtxWithVersion(TLSv1_3) + if err != nil { + t.Error(err) + return nil, err + } + + for _, f := range options { + if err := f(ctx); err != nil { + t.Error(err) + return nil, err + } + } + + certPEM, err := os.ReadFile(filepath.Join(testDir, "sm2-cert.pem")) + if err != nil { + t.Error(err) + return nil, err + } + + cert, err := crypto.LoadCertificateFromPEM(certPEM) + if err != nil { + t.Error(err) + return nil, err + } + + if err := ctx.UseCertificate(cert); err != nil { + t.Error(err) + return nil, err + } + + keyPEM, err := os.ReadFile(filepath.Join(testDir, "sm2.key")) + if err != nil { + t.Error(err) + return nil, err + } + + key, err := crypto.LoadPrivateKeyFromPEM(keyPEM) + if err != nil { + t.Error(err) + return nil, err + } + + if err := ctx.UsePrivateKey(key); err != nil { + t.Error(err) + return nil, err + } + + lis, err := Listen("tcp", "127.0.0.1:4433", ctx) + if err != nil { + t.Error(err) + return nil, err + } + + return &echoServer{lis}, nil +} + +func TestTLSv13SMCipher(t *testing.T) { + ciphers := []string{ + TLSSMGCMCipher, + TLSSMCCMCipher, + } + testCertDir := "test/certs" + + for _, cipher := range ciphers { + t.Run(cipher, func(t *testing.T) { + // Run server + server, err := newTLSv13SMCipherServer(t, testCertDir, func(sslctx *Ctx) error { + return sslctx.SetCipherSuites(cipher) + }) + if err != nil { + t.Error(err) + return + } + + defer server.Close() + go server.Run() + + // Run client + ctx, err := NewCtxWithVersion(TLSv1_3) + if err != nil { + t.Error(err) + return + } + + if err := ctx.SetCipherSuites(cipher); err != nil { + t.Error(err) + return + } + + conn, err := Dial("tcp", "127.0.0.1:4433", ctx, InsecureSkipHostVerification, "") + if err != nil { + t.Error(err) + return + } + defer conn.Close() + + cipher, err = conn.CurrentCipher() + if err != nil { + t.Error(err) + return + } + + t.Log("current cipher", cipher) + + request := "hello tongsuo\n" + if _, err := conn.Write([]byte(request)); err != nil { + t.Error(err) + return + } + + resp, err := bufio.NewReader(conn).ReadString('\n') + if err != nil { + t.Error(err) + return + } + + if resp != request { + t.Error("response data is not expected: ", resp) + return + } + }) + } +} + +func newTLSv13SMCipherServer(t *testing.T, testDir string, options ...func(sslctx *Ctx) error) (*echoServer, error) { + ctx, err := NewCtxWithVersion(TLSv1_3) + if err != nil { + t.Error(err) + return nil, err + } + + for _, f := range options { + if err := f(ctx); err != nil { + t.Error(err) + return nil, err + } + } + + certPEM, err := os.ReadFile(filepath.Join(testDir, "sm2-cert.pem")) + if err != nil { + t.Error(err) + return nil, err + } + + cert, err := crypto.LoadCertificateFromPEM(certPEM) + if err != nil { + t.Error(err) + return nil, err + } + + if err := ctx.UseCertificate(cert); err != nil { + t.Error(err) + return nil, err + } + + keyPEM, err := os.ReadFile(filepath.Join(testDir, "sm2.key")) + if err != nil { + t.Error(err) + return nil, err + } + + key, err := crypto.LoadPrivateKeyFromPEM(keyPEM) + if err != nil { + t.Error(err) + return nil, err + } + + if err := ctx.UsePrivateKey(key); err != nil { + t.Error(err) + return nil, err + } + + lis, err := Listen("tcp", "127.0.0.1:4433", ctx) + if err != nil { + t.Error(err) + return nil, err + } + + return &echoServer{lis}, nil +} diff --git a/shim.c b/shim.c index 08caeb8..b005996 100644 --- a/shim.c +++ b/shim.c @@ -75,6 +75,14 @@ int X_SSL_verify_cb(int ok, X509_STORE_CTX* store) { return go_ssl_verify_cb_thunk(p, ok, store); } +int X_SSL_CTX_set_max_proto_version(SSL_CTX *ctx, int version) { + return SSL_CTX_set_max_proto_version(ctx, version); +} + +int X_SSL_CTX_set_min_proto_version(SSL_CTX *ctx, int version) { + return SSL_CTX_set_min_proto_version(ctx, version); +} + const SSL_METHOD *X_SSLv23_method() { return SSLv23_method(); } diff --git a/shim.h b/shim.h index 870ec15..7472f58 100644 --- a/shim.h +++ b/shim.h @@ -85,6 +85,8 @@ extern int X_SSL_CTX_set_tlsext_ticket_key_cb(SSL_CTX *sslctx, extern int X_SSL_CTX_ticket_key_cb(SSL *s, unsigned char key_name[16], unsigned char iv[EVP_MAX_IV_LENGTH], EVP_CIPHER_CTX *cctx, HMAC_CTX *hctx, int enc); +extern int X_SSL_CTX_set_max_proto_version(SSL_CTX *ctx, int version); +extern int X_SSL_CTX_set_min_proto_version(SSL_CTX *ctx, int version); extern int X_X509_add_ref(X509* x509); extern int X_sk_X509_num(STACK_OF(X509) *sk); diff --git a/ssl_test.go b/ssl_test.go index b77127e..fcc8d8f 100644 --- a/ssl_test.go +++ b/ssl_test.go @@ -138,52 +138,68 @@ MC4CAQAwBQYDK2VwBCIEIL3QVwyuusKuLgZwZn356UHk9u1REGHbNTLtFMPKNQSb `) ) +// NetPipe creates a TCP connection pipe and returns two connections. func NetPipe(t testing.TB) (net.Conn, net.Conn) { l, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatal(err) } defer l.Close() - client_future := utils.NewFuture() + + // Use Future pattern to create client connection asynchronously + clientFuture := utils.NewFuture() go func() { - client_future.Set(net.Dial(l.Addr().Network(), l.Addr().String())) + clientFuture.Set(net.Dial(l.Addr().Network(), l.Addr().String())) }() + var errs utils.ErrorGroup - server_conn, err := l.Accept() + serverConn, err := l.Accept() errs.Add(err) - client_conn, err := client_future.Get() + clientConn, err := clientFuture.Get() errs.Add(err) + err = errs.Finalize() if err != nil { - if server_conn != nil { - server_conn.Close() + if serverConn != nil { + err := serverConn.Close() + if err != nil { + t.Fatal(err) + } } - if client_conn != nil { - client_conn.(net.Conn).Close() + if clientConn != nil { + err := clientConn.(net.Conn).Close() + if err != nil { + t.Fatal(err) + } } t.Fatal(err) } - return server_conn, client_conn.(net.Conn) + return serverConn, clientConn.(net.Conn) } +// HandshakingConn interface extends net.Conn interface with Handshake method. type HandshakingConn interface { net.Conn Handshake() error } +// SimpleConnTest tests simple SSL/TLS connections. func SimpleConnTest(t testing.TB, constructor func( t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { - server_conn, client_conn := NetPipe(t) - defer server_conn.Close() - defer client_conn.Close() + // Create network pipe + serverConn, clientConn := NetPipe(t) + defer serverConn.Close() + defer clientConn.Close() data := "first test string\n" - server, client := constructor(t, server_conn, client_conn) - defer close_both(server, client) + // Create SSL/TLS connections using provided constructor + server, client := constructor(t, serverConn, clientConn) + defer closeBoth(server, client) var wg sync.WaitGroup wg.Add(2) + go func() { defer wg.Done() @@ -202,9 +218,9 @@ func SimpleConnTest(t testing.TB, constructor func( t.Fatal(err) } }() + go func() { defer wg.Done() - // TODO check server.Close if err defer server.Close() err := server.Handshake() @@ -225,7 +241,8 @@ func SimpleConnTest(t testing.TB, constructor func( wg.Wait() } -func close_both(closer1, closer2 io.Closer) { +// closeBoth closes two connections. +func closeBoth(closer1, closer2 io.Closer) { var wg sync.WaitGroup wg.Add(2) go func() { @@ -239,18 +256,21 @@ func close_both(closer1, closer2 io.Closer) { wg.Wait() } +// ClosingTest tests connection closing scenarios. func ClosingTest(t testing.TB, constructor func( t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { - run_test := func(server_writes bool) { - server_conn, client_conn := NetPipe(t) - defer server_conn.Close() - defer client_conn.Close() - server, client := constructor(t, server_conn, client_conn) - defer close_both(server, client) + runTest := func(serverWrites bool) { + // Create network pipe + serverConn, clientConn := NetPipe(t) + defer serverConn.Close() + defer clientConn.Close() + server, client := constructor(t, serverConn, clientConn) + defer closeBoth(server, client) + // Determine who writes and who reads based on server_writes parameter var sslconn1, sslconn2 HandshakingConn - if server_writes { + if serverWrites { sslconn1 = server sslconn2 = client } else { @@ -260,6 +280,7 @@ func ClosingTest(t testing.TB, constructor func( var wg sync.WaitGroup wg.Add(2) + go func() { defer wg.Done() _, err := sslconn1.Write([]byte("hello")) @@ -284,19 +305,24 @@ func ClosingTest(t testing.TB, constructor func( wg.Wait() } - run_test(false) - run_test(true) + // Test both client writing and server writing scenarios + runTest(false) + runTest(true) } +// ThroughputBenchmark benchmarks SSL/TLS connection throughput. func ThroughputBenchmark(b *testing.B, constructor func( t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { - server_conn, client_conn := NetPipe(b) - defer server_conn.Close() - defer client_conn.Close() + // Create network pipe + serverConn, clientConn := NetPipe(b) + defer serverConn.Close() + defer clientConn.Close() - server, client := constructor(b, server_conn, client_conn) - defer close_both(server, client) + // Create SSL/TLS connections + server, client := constructor(b, serverConn, clientConn) + defer closeBoth(server, client) + // Set benchmark parameters b.SetBytes(1024) data := make([]byte, b.N*1024) _, err := io.ReadFull(rand.Reader, data[:]) @@ -307,6 +333,7 @@ func ThroughputBenchmark(b *testing.B, constructor func( b.ResetTimer() var wg sync.WaitGroup wg.Add(2) + go func() { defer wg.Done() _, err = io.Copy(client, bytes.NewReader([]byte(data))) @@ -314,6 +341,7 @@ func ThroughputBenchmark(b *testing.B, constructor func( b.Error(err) } }() + go func() { defer wg.Done() @@ -330,7 +358,8 @@ func ThroughputBenchmark(b *testing.B, constructor func( b.StopTimer() } -func StdlibConstructor(t testing.TB, server_conn, client_conn net.Conn) ( +// StdlibConstructor creates standard library SSL/TLS connections. +func StdlibConstructor(t testing.TB, serverConn, clientConn net.Conn) ( server, client HandshakingConn) { cert, err := tls.X509KeyPair(certBytes, keyBytes) if err != nil { @@ -339,12 +368,32 @@ func StdlibConstructor(t testing.TB, server_conn, client_conn net.Conn) ( config := &tls.Config{ Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true, - CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}} - server = tls.Server(server_conn, config) - client = tls.Client(client_conn, config) + CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, + } + server = tls.Server(serverConn, config) + client = tls.Client(clientConn, config) + return server, client +} + +// StdlibTLSv13Constructor creates standard library SSL/TLS connections with TLSv1.3. +func StdlibTLSv13Constructor(t testing.TB, serverConn, clientConn net.Conn) ( + server, client HandshakingConn) { + cert, err := tls.X509KeyPair(certBytes, keyBytes) + if err != nil { + t.Fatal(err) + } + config := &tls.Config{ + Certificates: []tls.Certificate{cert}, + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS13, + MaxVersion: tls.VersionTLS13, + } + server = tls.Server(serverConn, config) + client = tls.Client(clientConn, config) return server, client } +// passThruVerify is used to pass through certificate verification. func passThruVerify(t testing.TB) func(bool, *CertificateStoreCtx) bool { x := func(ok bool, store *CertificateStoreCtx) bool { cert := store.GetCurrentCert() @@ -360,7 +409,8 @@ func passThruVerify(t testing.TB) func(bool, *CertificateStoreCtx) bool { return x } -func OpenSSLConstructor(t testing.TB, server_conn, client_conn net.Conn) ( +// OpenSSLConstructor creates OpenSSL SSL/TLS connections. +func OpenSSLConstructor(t testing.TB, serverConn, clientConn net.Conn) ( server, client HandshakingConn) { ctx, err := NewCtx() if err != nil { @@ -387,106 +437,176 @@ func OpenSSLConstructor(t testing.TB, server_conn, client_conn net.Conn) ( if err != nil { t.Fatal(err) } - server, err = Server(server_conn, ctx) + server, err = Server(serverConn, ctx) if err != nil { t.Fatal(err) } - client, err = Client(client_conn, ctx) + client, err = Client(clientConn, ctx) if err != nil { t.Fatal(err) } return server, client } -func StdlibOpenSSLConstructor(t testing.TB, server_conn, client_conn net.Conn) ( +// OpenSSLTLSv3Constructor function is used to create SSL/TLS connections for OpenSSL and TLSv3. +func OpenSSLTLSv3Constructor(t testing.TB, serverConn, clientConn net.Conn) ( server, client HandshakingConn) { - server_std, _ := StdlibConstructor(t, server_conn, client_conn) - _, client_ssl := OpenSSLConstructor(t, server_conn, client_conn) - return server_std, client_ssl + ctx, err := NewCtxWithVersion(SSLv3) + if err != nil { + t.Fatal(err) + } + ctx.SetVerify(VerifyNone, passThruVerify(t)) + key, err := crypto.LoadPrivateKeyFromPEM(keyBytes) + if err != nil { + t.Fatal(err) + } + err = ctx.UsePrivateKey(key) + if err != nil { + t.Fatal(err) + } + cert, err := crypto.LoadCertificateFromPEM(certBytes) + if err != nil { + t.Fatal(err) + } + err = ctx.UseCertificate(cert) + if err != nil { + t.Fatal(err) + } + err = ctx.SetCipherList("AES128-SHA") + if err != nil { + t.Fatal(err) + } + server, err = Server(serverConn, ctx) + if err != nil { + t.Fatal(err) + } + client, err = Client(clientConn, ctx) + if err != nil { + t.Fatal(err) + } + return server, client } -func OpenSSLStdlibConstructor(t testing.TB, server_conn, client_conn net.Conn) ( +// StdlibOpenSSLConstructor function is used to create SSL/TLS connections for the standard library and OpenSSL. +func StdlibOpenSSLConstructor(t testing.TB, serverConn, clientConn net.Conn) ( server, client HandshakingConn) { - _, client_std := StdlibConstructor(t, server_conn, client_conn) - server_ssl, _ := OpenSSLConstructor(t, server_conn, client_conn) - return server_ssl, client_std + serverStd, _ := StdlibConstructor(t, serverConn, clientConn) + _, clientSsl := OpenSSLConstructor(t, serverConn, clientConn) + return serverStd, clientSsl } +// OpenSSLStdlibConstructor function is used to create SSL/TLS connections for OpenSSL and the standard library. +func OpenSSLStdlibConstructor(t testing.TB, serverConn, clientConn net.Conn) ( + server, client HandshakingConn) { + _, clientStd := StdlibConstructor(t, serverConn, clientConn) + serverSsl, _ := OpenSSLConstructor(t, serverConn, clientConn) + return serverSsl, clientStd +} + +// TestStdlibSimple function is used to test simple connections of the standard library. func TestStdlibSimple(t *testing.T) { SimpleConnTest(t, StdlibConstructor) } +// TestStdlibTLSv13Simple function is used to test simple connections of the standard library with TLSv1.3. +func TestStdlibTLSv13Simple(t *testing.T) { + SimpleConnTest(t, StdlibTLSv13Constructor) +} + +// TestOpenSSLSimple function is used to test simple connections of OpenSSL. func TestOpenSSLSimple(t *testing.T) { SimpleConnTest(t, OpenSSLConstructor) } +// TestStdlibClosing function is used to test closing connections of the standard library. func TestStdlibClosing(t *testing.T) { ClosingTest(t, StdlibConstructor) } +// TestStdlibTLSv13Closing function is used to test closing connections of the standard library with TLSv1.3. +func TestStdlibTLSv13Closing(t *testing.T) { + ClosingTest(t, StdlibTLSv13Constructor) +} + // TODO fix this //func TestOpenSSLClosing(t *testing.T) { // ClosingTest(t, OpenSSLConstructor) //} +// BenchmarkStdlibThroughput function is used to benchmark the throughput of the standard library. func BenchmarkStdlibThroughput(b *testing.B) { ThroughputBenchmark(b, StdlibConstructor) } +// BenchmarkStdlibTLSv13Throughput function is used to benchmark the throughput of the standard library with TLSv1.3. +func BenchmarkStdlibTLSv13Throughput(b *testing.B) { + ThroughputBenchmark(b, StdlibTLSv13Constructor) +} + +// BenchmarkOpenSSLThroughput function is used to benchmark the throughput of OpenSSL. func BenchmarkOpenSSLThroughput(b *testing.B) { ThroughputBenchmark(b, OpenSSLConstructor) } +// TestStdlibOpenSSLSimple function is used to test simple connections of the standard library and OpenSSL. func TestStdlibOpenSSLSimple(t *testing.T) { SimpleConnTest(t, StdlibOpenSSLConstructor) } +// TestOpenSSLStdlibSimple function is used to test simple connections of OpenSSL and the standard library. func TestOpenSSLStdlibSimple(t *testing.T) { SimpleConnTest(t, OpenSSLStdlibConstructor) } +// TestStdlibOpenSSLClosing function is used to test closing connections of the standard library and OpenSSL. func TestStdlibOpenSSLClosing(t *testing.T) { ClosingTest(t, StdlibOpenSSLConstructor) } +// TestOpenSSLStdlibClosing function is used to test closing connections of OpenSSL and the standard library. func TestOpenSSLStdlibClosing(t *testing.T) { ClosingTest(t, OpenSSLStdlibConstructor) } +// BenchmarkStdlibOpenSSLThroughput function is used to benchmark the throughput of the standard library and OpenSSL. func BenchmarkStdlibOpenSSLThroughput(b *testing.B) { ThroughputBenchmark(b, StdlibOpenSSLConstructor) } +// BenchmarkOpenSSLStdlibThroughput function is used to benchmark the throughput of OpenSSL and the standard library. func BenchmarkOpenSSLStdlibThroughput(b *testing.B) { ThroughputBenchmark(b, OpenSSLStdlibConstructor) } +// FullDuplexRenegotiationTest function is used to test full-duplex renegotiation. func FullDuplexRenegotiationTest(t testing.TB, constructor func( t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { SSLRecordSize := 16 * 1024 - server_conn, client_conn := NetPipe(t) - defer server_conn.Close() - defer client_conn.Close() + serverConn, clientConn := NetPipe(t) + defer serverConn.Close() + defer clientConn.Close() + // Set test parameters times := 256 - data_len := 4 * SSLRecordSize - data1 := make([]byte, data_len) + dataLen := 4 * SSLRecordSize + data1 := make([]byte, dataLen) _, err := io.ReadFull(rand.Reader, data1[:]) if err != nil { t.Fatal(err) } - data2 := make([]byte, data_len) + data2 := make([]byte, dataLen) _, err = io.ReadFull(rand.Reader, data1[:]) if err != nil { t.Fatal(err) } - server, client := constructor(t, server_conn, client_conn) - defer close_both(server, client) + // Create SSL/TLS connections + server, client := constructor(t, serverConn, clientConn) + defer closeBoth(server, client) var wg sync.WaitGroup - send_func := func(sender HandshakingConn, data []byte) { + sendFunc := func(sender HandshakingConn, data []byte) { defer wg.Done() for i := 0; i < times; i++ { if i == times/2 { @@ -506,7 +626,7 @@ func FullDuplexRenegotiationTest(t testing.TB, constructor func( } } - recv_func := func(receiver net.Conn, data []byte) { + recvFunc := func(receiver net.Conn, data []byte) { defer wg.Done() buf := make([]byte, len(data)) @@ -522,40 +642,51 @@ func FullDuplexRenegotiationTest(t testing.TB, constructor func( } wg.Add(4) - go recv_func(server, data1) - go send_func(client, data1) - go send_func(server, data2) - go recv_func(client, data2) + go recvFunc(server, data1) + go sendFunc(client, data1) + go sendFunc(server, data2) + go recvFunc(client, data2) wg.Wait() } +// TestStdlibFullDuplexRenegotiation function is used to test full-duplex renegotiation of the standard library. func TestStdlibFullDuplexRenegotiation(t *testing.T) { FullDuplexRenegotiationTest(t, StdlibConstructor) } +// TestStdlibTLSv13FullDuplexRenegotiation function is used to test full-duplex renegotiation of the standard library with TLSv1.3. +func TestStdlibTLSv13FullDuplexRenegotiation(t *testing.T) { + FullDuplexRenegotiationTest(t, StdlibTLSv13Constructor) +} + +// TestOpenSSLFullDuplexRenegotiation function is used to test full-duplex renegotiation of OpenSSL. func TestOpenSSLFullDuplexRenegotiation(t *testing.T) { FullDuplexRenegotiationTest(t, OpenSSLConstructor) } +// TestOpenSSLStdlibFullDuplexRenegotiation function is used to test full-duplex renegotiation of OpenSSL and the standard library. func TestOpenSSLStdlibFullDuplexRenegotiation(t *testing.T) { FullDuplexRenegotiationTest(t, OpenSSLStdlibConstructor) } +// TestStdlibOpenSSLFullDuplexRenegotiation function is used to test full-duplex renegotiation of the standard library and OpenSSL. func TestStdlibOpenSSLFullDuplexRenegotiation(t *testing.T) { FullDuplexRenegotiationTest(t, StdlibOpenSSLConstructor) } -func LotsOfConns(t *testing.T, payload_size int64, loops, clients int, +// LotsOfConns function is used to test the situation of a large number of connections. +func LotsOfConns(t *testing.T, payloadSize int64, loops, clients int, sleep time.Duration, newListener func(net.Listener) net.Listener, newClient func(net.Conn) (net.Conn, error)) { - tcp_listener, err := net.Listen("tcp", "localhost:0") + tcpListener, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatal(err) } - ssl_listener := newListener(tcp_listener) + sslListener := newListener(tcpListener) + go func() { for { - conn, err := ssl_listener.Accept() + conn, err := sslListener.Accept() if err != nil { t.Error("failed accept: ", err) continue @@ -569,13 +700,13 @@ func LotsOfConns(t *testing.T, payload_size int64, loops, clients int, }() for i := 0; i < loops; i++ { _, err := io.Copy(ioutil.Discard, - io.LimitReader(conn, payload_size)) + io.LimitReader(conn, payloadSize)) if err != nil { t.Error("failed reading: ", err) return } _, err = io.Copy(conn, io.LimitReader(rand.Reader, - payload_size)) + payloadSize)) if err != nil { t.Error("failed writing: ", err) return @@ -585,35 +716,38 @@ func LotsOfConns(t *testing.T, payload_size int64, loops, clients int, }() } }() + + // Create multiple client connections var wg sync.WaitGroup for i := 0; i < clients; i++ { - tcp_client, err := net.Dial(tcp_listener.Addr().Network(), - tcp_listener.Addr().String()) + tcpClient, err := net.Dial(tcpListener.Addr().Network(), + tcpListener.Addr().String()) if err != nil { t.Fatal(err) } - ssl_client, err := newClient(tcp_client) + sslClient, err := newClient(tcpClient) if err != nil { t.Fatal(err) } wg.Add(1) go func(i int) { defer func() { - err = ssl_client.Close() + err = sslClient.Close() if err != nil { t.Error("failed closing: ", err) } wg.Done() }() for i := 0; i < loops; i++ { - _, err := io.Copy(ssl_client, io.LimitReader(rand.Reader, - payload_size)) + // Write and read data + _, err := io.Copy(sslClient, io.LimitReader(rand.Reader, + payloadSize)) if err != nil { t.Error("failed writing: ", err) return } _, err = io.Copy(ioutil.Discard, - io.LimitReader(ssl_client, payload_size)) + io.LimitReader(sslClient, payloadSize)) if err != nil { t.Error("failed reading: ", err) return @@ -625,24 +759,51 @@ func LotsOfConns(t *testing.T, payload_size int64, loops, clients int, wg.Wait() } +// TestStdlibLotsOfConns function is used to test the situation of a large number of connections of the standard library. func TestStdlibLotsOfConns(t *testing.T) { - tls_cert, err := tls.X509KeyPair(certBytes, keyBytes) + // Load certificate and configure TLS + tlsCert, err := tls.X509KeyPair(certBytes, keyBytes) if err != nil { t.Fatal(err) } - tls_config := &tls.Config{ - Certificates: []tls.Certificate{tls_cert}, + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, InsecureSkipVerify: true, CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}} + // Execute large number of connections test + LotsOfConns(t, 1024*64, 10, 100, 0*time.Second, + func(l net.Listener) net.Listener { + return tls.NewListener(l, tlsConfig) + }, func(c net.Conn) (net.Conn, error) { + return tls.Client(c, tlsConfig), nil + }) +} + +// TestStdlibTLSv13LotsOfConns function is used to test the situation of a large number of connections of the standard library with TLSv1.3. +func TestStdlibTLSv13LotsOfConns(t *testing.T) { + // Load certificate and configure TLS + tlsCert, err := tls.X509KeyPair(certBytes, keyBytes) + if err != nil { + t.Fatal(err) + } + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS13, + MaxVersion: tls.VersionTLS13, + } + // Execute large number of connections test LotsOfConns(t, 1024*64, 10, 100, 0*time.Second, func(l net.Listener) net.Listener { - return tls.NewListener(l, tls_config) + return tls.NewListener(l, tlsConfig) }, func(c net.Conn) (net.Conn, error) { - return tls.Client(c, tls_config), nil + return tls.Client(c, tlsConfig), nil }) } +// TestOpenSSLLotsOfConns function is used to test the situation of a large number of connections of OpenSSL. func TestOpenSSLLotsOfConns(t *testing.T) { + // Create SSL context and configure ctx, err := NewCtx() if err != nil { t.Fatal(err) @@ -667,6 +828,7 @@ func TestOpenSSLLotsOfConns(t *testing.T) { if err != nil { t.Fatal(err) } + // Execute large number of connections test LotsOfConns(t, 1024*64, 10, 100, 0*time.Second, func(l net.Listener) net.Listener { return NewListener(l, ctx) diff --git a/test/certs/sm2-cert.pem b/test/certs/sm2-cert.pem new file mode 100644 index 0000000..d8c56f7 --- /dev/null +++ b/test/certs/sm2-cert.pem @@ -0,0 +1,13 @@ +-----BEGIN CERTIFICATE----- +MIIB/zCCAaSgAwIBAgIUXL0hEsZsfetdIiPFtxDgKFHRfWswCgYIKoEcz1UBg3Uw +TjEPMA0GA1UECgwGU00yLUNBMSQwIgYDVQQDDBtTTTIgY2VydGlmaWNhdGUgc2ln +bmluZyBrZXkxFTATBgkqhkiG9w0BCQEWBmNhQHNtMjAgFw0yMzA0MjUwNjUxNTZa +GA8yMDUwMDkxMDA2NTE1NlowTjEPMA0GA1UECgwGU00yLUNBMSQwIgYDVQQDDBtT +TTIgY2VydGlmaWNhdGUgc2lnbmluZyBrZXkxFTATBgkqhkiG9w0BCQEWBmNhQHNt +MjBaMBQGCCqBHM9VAYItBggqgRzPVQGCLQNCAARmJiuMDZvqstW5mi1yj931E6S5 +jxkJdqO4hE0A4n05vGGlc/K0JrCNNDPxYM97jcmbNDogre5Vh6m+9IGLtXvNo10w +WzAMBgNVHRMBAf8EAjAAMAsGA1UdDwQEAwIGwDAdBgNVHQ4EFgQU1Q7WULoMM5g/ +AyY0AxfrrvehM9cwHwYDVR0jBBgwFoAUPjscE6pUGFdHLY1byo1sbfpdspMwCgYI +KoEcz1UBg3UDSQAwRgIhALwTlkh4uvDB/S9bP0m/pxrf6D5yBjOqqojCDCyflrVY +AiEA9kzSE4ASlZDZ9HLxg4QZ/+4Wj18yrOpNEIugmGcP52w= +-----END CERTIFICATE----- diff --git a/test/certs/sm2.key b/test/certs/sm2.key new file mode 100644 index 0000000..b8aeace --- /dev/null +++ b/test/certs/sm2.key @@ -0,0 +1,5 @@ +-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqBHM9VAYItBG0wawIBAQQgf9J4nk4pDYtXbngw +1epbcqxbRVYuAML+W90yJatzBRWhRANCAARmJiuMDZvqstW5mi1yj931E6S5jxkJ +dqO4hE0A4n05vGGlc/K0JrCNNDPxYM97jcmbNDogre5Vh6m+9IGLtXvN +-----END PRIVATE KEY-----