Skip to content

Commit 72a24dd

Browse files
committed
Support tls1.3 and ShangMi cipher suit
1 parent 59bd11c commit 72a24dd

File tree

9 files changed

+591
-107
lines changed

9 files changed

+591
-107
lines changed

ctx.go

+30-17
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,12 @@ func newCtx(method *C.SSL_METHOD) (*Ctx, error) {
7676
type SSLVersion int
7777

7878
const (
79-
SSLv3 SSLVersion = 0x02 // Vulnerable to "POODLE" attack.
80-
TLSv1 SSLVersion = 0x03
81-
TLSv1_1 SSLVersion = 0x04
82-
TLSv1_2 SSLVersion = 0x05
83-
NTLS SSLVersion = 0x06
79+
SSLv3 SSLVersion = 0x0300 // Vulnerable to "POODLE" attack.
80+
TLSv1 SSLVersion = 0x0301
81+
TLSv1_1 SSLVersion = 0x0302
82+
TLSv1_2 SSLVersion = 0x0303
83+
TLSv1_3 SSLVersion = 0x0304
84+
NTLS SSLVersion = 0x0101
8485

8586
// AnyVersion Make sure to disable SSLv2 and SSLv3 if you use this. SSLv3 is vulnerable
8687
// to the "POODLE" attack, and SSLv2 is what, just don't even.
@@ -92,20 +93,11 @@ const (
9293
func NewCtxWithVersion(version SSLVersion) (*Ctx, error) {
9394
var enableNTLS bool
9495
var method *C.SSL_METHOD
95-
switch version {
96-
case SSLv3:
97-
method = C.X_SSLv3_method()
98-
case TLSv1:
99-
method = C.X_TLSv1_method()
100-
case TLSv1_1:
101-
method = C.X_TLSv1_1_method()
102-
case TLSv1_2:
103-
method = C.X_TLSv1_2_method()
104-
case NTLS:
96+
if version == NTLS {
10597
method = C.X_NTLS_method()
10698
enableNTLS = true
107-
case AnyVersion:
108-
method = C.X_SSLv23_method()
99+
} else {
100+
method = C.TLS_method()
109101
}
110102
if method == nil {
111103
return nil, errors.New("unknown ssl/tls version")
@@ -118,6 +110,12 @@ func NewCtxWithVersion(version SSLVersion) (*Ctx, error) {
118110

119111
if enableNTLS {
120112
C.X_SSL_CTX_enable_ntls(c.ctx)
113+
} else if version == AnyVersion {
114+
C.X_SSL_CTX_set_min_proto_version(c.ctx, C.int(TLSv1))
115+
C.X_SSL_CTX_set_max_proto_version(c.ctx, C.int(TLSv1_3))
116+
} else {
117+
C.X_SSL_CTX_set_min_proto_version(c.ctx, C.int(version))
118+
C.X_SSL_CTX_set_max_proto_version(c.ctx, C.int(version))
121119
}
122120

123121
return c, nil
@@ -646,14 +644,29 @@ func (c *Ctx) SetSessionId(session_id []byte) error {
646644
func (c *Ctx) SetCipherList(list string) error {
647645
runtime.LockOSThread()
648646
defer runtime.UnlockOSThread()
647+
649648
clist := C.CString(list)
650649
defer C.free(unsafe.Pointer(clist))
650+
651651
if int(C.SSL_CTX_set_cipher_list(c.ctx, clist)) == 0 {
652652
return crypto.ErrorFromErrorQueue()
653653
}
654654
return nil
655655
}
656656

657+
func (c *Ctx) SetCipherSuites(suites string) error {
658+
runtime.LockOSThread()
659+
defer runtime.UnlockOSThread()
660+
661+
csuits := C.CString(suites)
662+
defer C.free(unsafe.Pointer(csuits))
663+
664+
if int(C.SSL_CTX_set_ciphersuites(c.ctx, csuits)) == 0 {
665+
return crypto.ErrorFromErrorQueue()
666+
}
667+
return nil
668+
}
669+
657670
type SessionCacheModes int
658671

659672
const (

examples/tlcp_client/main.go

+20-5
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func main() {
2828
connAddr := ""
2929
serverName := ""
3030
alpnProtocols := []string{"h2", "http/1.1"}
31-
31+
tlsVersion := ""
3232
flag.StringVar(&connAddr, "conn", "127.0.0.1:4438", "host:port")
3333
flag.StringVar(&cipherSuite, "cipher", "ECC-SM2-SM4-CBC-SM3", "cipher suite")
3434
flag.StringVar(&signCertFile, "sign_cert", "test/certs/sm2/client_sign.crt", "sign certificate file")
@@ -38,12 +38,27 @@ func main() {
3838
flag.StringVar(&caFile, "CAfile", "test/certs/sm2/chain-ca.crt", "CA certificate file")
3939
flag.StringVar(&serverName, "servername", "", "server name")
4040
flag.Var((*stringSlice)(&alpnProtocols), "alpn", "ALPN protocols")
41-
41+
flag.StringVar(&tlsVersion, "version", "NTLS", "TLS version")
4242
flag.Parse()
4343

44-
ctx, err := ts.NewCtxWithVersion(ts.NTLS)
44+
var version ts.SSLVersion
45+
switch tlsVersion {
46+
case "TLSv1.3":
47+
version = ts.TLSv1_3
48+
case "TLSv1.2":
49+
version = ts.TLSv1_2
50+
case "TLSv1.1":
51+
version = ts.TLSv1_1
52+
case "TLSv1":
53+
version = ts.TLSv1
54+
case "NTLS":
55+
version = ts.NTLS
56+
default:
57+
version = ts.NTLS
58+
}
59+
ctx, err := ts.NewCtxWithVersion(version)
4560
if err != nil {
46-
panic(err)
61+
panic("NewCtxWithVersion failed: " + err.Error())
4762
}
4863

4964
if err := ctx.SetClientALPNProtos(alpnProtocols); err != nil {
@@ -123,7 +138,7 @@ func main() {
123138

124139
conn, err := ts.Dial("tcp", connAddr, ctx, ts.InsecureSkipHostVerification, serverName)
125140
if err != nil {
126-
panic(err)
141+
panic("connected failed" + err.Error())
127142
}
128143
defer conn.Close()
129144

examples/tlcp_server/main.go

+28-7
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@ import (
1111
"bufio"
1212
"flag"
1313
"fmt"
14-
ts "github.com/tongsuo-project/tongsuo-go-sdk"
15-
"github.com/tongsuo-project/tongsuo-go-sdk/crypto"
1614
"log"
1715
"net"
1816
"os"
1917
"path/filepath"
2018
"strings"
19+
20+
ts "github.com/tongsuo-project/tongsuo-go-sdk"
21+
"github.com/tongsuo-project/tongsuo-go-sdk/crypto"
2122
)
2223

2324
func ReadCertificateFiles(dirPath string) (map[string]crypto.GMDoubleCertKey, error) {
@@ -89,14 +90,33 @@ func handleConn(conn net.Conn) {
8990
log.Println("Close connection")
9091
}
9192

92-
func newNTLSServer(acceptAddr string, certKeyPairs map[string]crypto.GMDoubleCertKey, cafile string, alpnProtocols []string) (net.Listener, error) {
93-
94-
ctx, err := ts.NewCtxWithVersion(ts.NTLS)
93+
func newTLSServer(acceptAddr string, certKeyPairs map[string]crypto.GMDoubleCertKey, cafile string, alpnProtocols []string, tlsVersion string) (net.Listener, error) {
94+
var version ts.SSLVersion
95+
switch tlsVersion {
96+
case "TLSv1.3":
97+
version = ts.TLSv1_3
98+
case "TLSv1.2":
99+
version = ts.TLSv1_2
100+
case "TLSv1.1":
101+
version = ts.TLSv1_1
102+
case "TLSv1":
103+
version = ts.TLSv1
104+
case "NTLS":
105+
version = ts.NTLS
106+
default:
107+
version = ts.TLSv1_3
108+
}
109+
ctx, err := ts.NewCtxWithVersion(version)
95110
if err != nil {
96111
log.Println(err)
97112
return nil, err
98113
}
99114

115+
err = ctx.SetCipherList("ECC-SM2-SM4-CBC-SM3")
116+
if err != nil {
117+
return nil, err
118+
}
119+
100120
if err := ctx.LoadVerifyLocations(cafile, ""); err != nil {
101121
log.Println(err)
102122
return nil, err
@@ -286,6 +306,7 @@ func main() {
286306
caFile := ""
287307
acceptAddr := ""
288308
alpnProtocols := []string{"h2", "http/1.1"}
309+
tlsVersion := ""
289310

290311
flag.StringVar(&acceptAddr, "accept", "127.0.0.1:4438", "host:port")
291312
flag.StringVar(&signCertFile, "sign_cert", "test/certs/sm2/server_sign.crt", "sign certificate file")
@@ -294,7 +315,7 @@ func main() {
294315
flag.StringVar(&encKeyFile, "enc_key", "test/certs/sm2/server_enc.key", "encrypt private key file")
295316
flag.StringVar(&caFile, "CAfile", "test/certs/sm2/chain-ca.crt", "CA certificate file")
296317
flag.Var((*stringSlice)(&alpnProtocols), "alpn", "ALPN protocols")
297-
318+
flag.StringVar(&tlsVersion, "version", "NTLS", "TLS version")
298319
flag.Parse()
299320

300321
certFiles, err := ReadCertificateFiles("test/sni_certs")
@@ -303,7 +324,7 @@ func main() {
303324
return
304325
}
305326

306-
server, err := newNTLSServer(acceptAddr, certFiles, caFile, alpnProtocols)
327+
server, err := newTLSServer(acceptAddr, certFiles, caFile, alpnProtocols, tlsVersion)
307328
if err != nil {
308329
return
309330
}

0 commit comments

Comments
 (0)