Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support TLS1.3 and ShangMi Ciphersuites. #35

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 30 additions & 17 deletions ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,12 @@ func newCtx(method *C.SSL_METHOD) (*Ctx, error) {
type SSLVersion int

const (
SSLv3 SSLVersion = 0x02 // Vulnerable to "POODLE" attack.
TLSv1 SSLVersion = 0x03
TLSv1_1 SSLVersion = 0x04
TLSv1_2 SSLVersion = 0x05
NTLS SSLVersion = 0x06
SSLv3 SSLVersion = 0x0300 // Vulnerable to "POODLE" attack.
TLSv1 SSLVersion = 0x0301
TLSv1_1 SSLVersion = 0x0302
TLSv1_2 SSLVersion = 0x0303
TLSv1_3 SSLVersion = 0x0304
NTLS SSLVersion = 0x0101

// AnyVersion Make sure to disable SSLv2 and SSLv3 if you use this. SSLv3 is vulnerable
// to the "POODLE" attack, and SSLv2 is what, just don't even.
Expand All @@ -92,20 +93,11 @@ const (
func NewCtxWithVersion(version SSLVersion) (*Ctx, error) {
var enableNTLS bool
var method *C.SSL_METHOD
switch version {
case SSLv3:
method = C.X_SSLv3_method()
case TLSv1:
method = C.X_TLSv1_method()
case TLSv1_1:
method = C.X_TLSv1_1_method()
case TLSv1_2:
method = C.X_TLSv1_2_method()
case NTLS:
if version == NTLS {
method = C.X_NTLS_method()
enableNTLS = true
case AnyVersion:
method = C.X_SSLv23_method()
} else {
method = C.TLS_method()
}
if method == nil {
return nil, errors.New("unknown ssl/tls version")
Expand All @@ -118,6 +110,12 @@ func NewCtxWithVersion(version SSLVersion) (*Ctx, error) {

if enableNTLS {
C.X_SSL_CTX_enable_ntls(c.ctx)
} else if version == AnyVersion {
C.X_SSL_CTX_set_min_proto_version(c.ctx, C.int(TLSv1))
C.X_SSL_CTX_set_max_proto_version(c.ctx, C.int(TLSv1_3))
} else {
C.X_SSL_CTX_set_min_proto_version(c.ctx, C.int(version))
C.X_SSL_CTX_set_max_proto_version(c.ctx, C.int(version))
}

return c, nil
Expand Down Expand Up @@ -646,14 +644,29 @@ func (c *Ctx) SetSessionId(session_id []byte) error {
func (c *Ctx) SetCipherList(list string) error {
runtime.LockOSThread()
defer runtime.UnlockOSThread()

clist := C.CString(list)
defer C.free(unsafe.Pointer(clist))

if int(C.SSL_CTX_set_cipher_list(c.ctx, clist)) == 0 {
return crypto.ErrorFromErrorQueue()
}
return nil
}

func (c *Ctx) SetCipherSuites(suites string) error {
runtime.LockOSThread()
defer runtime.UnlockOSThread()

csuits := C.CString(suites)
defer C.free(unsafe.Pointer(csuits))

if int(C.SSL_CTX_set_ciphersuites(c.ctx, csuits)) == 0 {
return crypto.ErrorFromErrorQueue()
}
return nil
}

type SessionCacheModes int

const (
Expand Down
140 changes: 82 additions & 58 deletions examples/tlcp_client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"github.com/tongsuo-project/tongsuo-go-sdk/crypto"
)

var cipherSuites = ""

func main() {
cipherSuite := ""
signCertFile := ""
Expand All @@ -28,6 +30,7 @@ func main() {
connAddr := ""
serverName := ""
alpnProtocols := []string{"h2", "http/1.1"}
tlsVersion := ""

flag.StringVar(&connAddr, "conn", "127.0.0.1:4438", "host:port")
flag.StringVar(&cipherSuite, "cipher", "ECC-SM2-SM4-CBC-SM3", "cipher suite")
Expand All @@ -38,92 +41,113 @@ func main() {
flag.StringVar(&caFile, "CAfile", "test/certs/sm2/chain-ca.crt", "CA certificate file")
flag.StringVar(&serverName, "servername", "", "server name")
flag.Var((*stringSlice)(&alpnProtocols), "alpn", "ALPN protocols")

flag.StringVar(&tlsVersion, "version", "NTLS", "TLS version")
flag.StringVar(&cipherSuites, "ciphersuites", "ECC-SM2-SM4-CBC-SM3", "cipherSuites")
flag.Parse()

ctx, err := ts.NewCtxWithVersion(ts.NTLS)
var version ts.SSLVersion
switch tlsVersion {
case "TLSv1.3":
version = ts.TLSv1_3
case "TLSv1.2":
version = ts.TLSv1_2
case "TLSv1.1":
version = ts.TLSv1_1
case "TLSv1":
version = ts.TLSv1
case "NTLS":
version = ts.NTLS
default:
version = ts.NTLS
}
ctx, err := ts.NewCtxWithVersion(version)
if err != nil {
panic(err)
panic("NewCtxWithVersion failed: " + err.Error())
}

if err := ctx.SetClientALPNProtos(alpnProtocols); err != nil {
panic(err)
}

if err := ctx.SetCipherList(cipherSuite); err != nil {
panic(err)
}

if signCertFile != "" {
signCertPEM, err := os.ReadFile(signCertFile)
if err != nil {
if version >= ts.TLSv1_3 {
if err := ctx.SetCipherSuites(cipherSuites); err != nil {
panic(err)
}
signCert, err := crypto.LoadCertificateFromPEM(signCertPEM)
if err != nil {
} else {
if err := ctx.SetCipherList(cipherSuites); err != nil {
panic(err)
}

if err := ctx.UseSignCertificate(signCert); err != nil {
panic(err)
if signCertFile != "" {
signCertPEM, err := os.ReadFile(signCertFile)
if err != nil {
panic(err)
}
signCert, err := crypto.LoadCertificateFromPEM(signCertPEM)
if err != nil {
panic(err)
}

if err := ctx.UseSignCertificate(signCert); err != nil {
panic(err)
}
}
}

if signKeyFile != "" {
signKeyPEM, err := os.ReadFile(signKeyFile)
if err != nil {
panic(err)
}
signKey, err := crypto.LoadPrivateKeyFromPEM(signKeyPEM)
if err != nil {
panic(err)
if signKeyFile != "" {
signKeyPEM, err := os.ReadFile(signKeyFile)
if err != nil {
panic(err)
}
signKey, err := crypto.LoadPrivateKeyFromPEM(signKeyPEM)
if err != nil {
panic(err)
}

if err := ctx.UseSignPrivateKey(signKey); err != nil {
panic(err)
}
}

if err := ctx.UseSignPrivateKey(signKey); err != nil {
panic(err)
if encCertFile != "" {
encCertPEM, err := os.ReadFile(encCertFile)
if err != nil {
panic(err)
}
encCert, err := crypto.LoadCertificateFromPEM(encCertPEM)
if err != nil {
panic(err)
}

if err := ctx.UseEncryptCertificate(encCert); err != nil {
panic(err)
}
}
}

if encCertFile != "" {
encCertPEM, err := os.ReadFile(encCertFile)
if err != nil {
panic(err)
}
encCert, err := crypto.LoadCertificateFromPEM(encCertPEM)
if err != nil {
panic(err)
}

if err := ctx.UseEncryptCertificate(encCert); err != nil {
panic(err)
}
}

if encKeyFile != "" {
encKeyPEM, err := os.ReadFile(encKeyFile)
if err != nil {
panic(err)
}
if encKeyFile != "" {
encKeyPEM, err := os.ReadFile(encKeyFile)
if err != nil {
panic(err)
}

encKey, err := crypto.LoadPrivateKeyFromPEM(encKeyPEM)
if err != nil {
panic(err)
}
encKey, err := crypto.LoadPrivateKeyFromPEM(encKeyPEM)
if err != nil {
panic(err)
}

if err := ctx.UseEncryptPrivateKey(encKey); err != nil {
panic(err)
if err := ctx.UseEncryptPrivateKey(encKey); err != nil {
panic(err)
}
}
}

if caFile != "" {
if err := ctx.LoadVerifyLocations(caFile, ""); err != nil {
panic(err)
if caFile != "" {
if err := ctx.LoadVerifyLocations(caFile, ""); err != nil {
panic(err)
}
}
}

conn, err := ts.Dial("tcp", connAddr, ctx, ts.InsecureSkipHostVerification, serverName)
if err != nil {
panic(err)
panic("connected failed" + err.Error())
}
defer conn.Close()

Expand Down
Loading
Loading