From ffe81af84d99d22c8f919478ba7b6c5e064d4668 Mon Sep 17 00:00:00 2001 From: K1 Date: Mon, 28 Oct 2024 20:28:18 +0800 Subject: [PATCH] Add golangci-lint and fix lint issues; add CI for macOS and windows CI adds golangci-lint check. Delete duplicate cgo ldflags -lcrypto Fix golangci-lint issues. - Set max length of line to 120 - Test cases add t.Parallel() - Test helpers add t.Helper() - Warp the error with fmt.Errorf() - Rename variables from C style to Golang style - Add blank line. Add CI for macOS and Windows. Fix ldflags for windows. Fix compile errors on windows. --- .github/workflows/main.yml | 142 ++++- .golangci.yml | 89 +++ build.go | 6 +- build_static.go | 6 +- conn.go | 327 ++++++----- crypto/bio.go | 236 ++++---- crypto/build.go | 2 +- crypto/build_static.go | 2 +- crypto/cert.go | 237 ++++---- crypto/cert_test.go | 201 ++++--- crypto/ciphers.go | 99 ++-- crypto/ciphers_gcm.go | 43 +- crypto/ciphers_test.go | 155 ++++-- crypto/dh.go | 18 +- crypto/dh_test.go | 16 +- crypto/dhparam.go | 15 +- crypto/digest.go | 3 +- crypto/engine.go | 6 +- crypto/hmac.go | 49 +- crypto/hmac_test.go | 22 +- crypto/hostname.go | 29 +- crypto/init.go | 43 +- crypto/init_windows.go | 2 - crypto/key.go | 436 +++++++-------- crypto/key_test.go | 465 +++++++++++----- crypto/mapping.go | 6 +- crypto/md5/md5.go | 62 +-- crypto/md5/md5_test.go | 32 +- crypto/nid.go | 382 ++++++------- crypto/sha1/sha1.go | 60 ++- crypto/sha1/sha1_test.go | 27 +- crypto/sha256/sha256.go | 47 +- crypto/sha256/sha256_test.go | 26 +- crypto/sm2/sm2.go | 82 +-- crypto/sm2/sm2_test.go | 92 ++-- crypto/sm3/sm3.go | 58 +- crypto/sm3/sm3_test.go | 80 ++- crypto/sm4/sm4.go | 131 ++--- crypto/sm4/sm4_test.go | 91 +++- ctx.go | 360 +++++++------ ctx_test.go | 16 +- examples/cert_gen/main.go | 30 +- examples/sm2_encrypt/main.go | 2 + examples/sm2_keygen/main.go | 1 + examples/sm2_sign/main.go | 1 + examples/sm2_signasn1/main.go | 2 + examples/sm3/main.go | 3 +- examples/sm4/main.go | 8 +- examples/tlcp_client/main.go | 8 + examples/tlcp_server/main.go | 21 +- http.go | 24 +- init.go | 8 +- net.go | 78 ++- ntls_test.go | 985 ++++++++++++---------------------- pem.go | 8 +- shim.c | 5 +- shim.h | 2 +- sni.c | 4 +- ssl.go | 66 +-- ssl_test.go | 564 +++++++++++-------- tickets.go | 56 +- utils/errors.go | 11 +- utils/future.go | 42 +- 63 files changed, 3429 insertions(+), 2701 deletions(-) create mode 100644 .golangci.yml diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b3afb64..a54e710 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -6,7 +6,7 @@ name: CI on: [push, pull_request] jobs: - build: + golang-lint: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v3 @@ -22,34 +22,74 @@ jobs: gofmt-path: './' gofmt-flags: '-l -d' + - name: Go Mod + run: go mod tidy + - name: Clone Tongsuo uses: actions/checkout@v3 with: repository: Tongsuo-Project/Tongsuo - path: tongsuo + path: Tongsuo ref: 8.3-stable - name: Build Tongsuo run: | - cd tongsuo - ./config --prefix=/opt/tongsuo --libdir=/opt/tongsuo/lib enable-ntls + cd Tongsuo + ./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls make -j4 make install - - name: Go Mod - run: go mod tidy + - name: Golang lint + run: | + curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b /usr/local/bin v1.61.0 + LD_LIBRARY_PATH=${RUNNER_TEMP}/tongsuo/lib CGO_CFLAGS="-Wall -I${RUNNER_TEMP}/tongsuo/include -Wno-deprecated-declarations" CGO_LDFLAGS="-L${RUNNER_TEMP}/tongsuo/lib" golangci-lint run ./... - - name: Go vet Check - run: LD_LIBRARY_PATH=/opt/tongsuo/lib CGO_CFLAGS="-Wall -I/opt/tongsuo/include -Wno-deprecated-declarations" CGO_LDFLAGS="-L/opt/tongsuo/lib" go vet ./... - - name: Build - run: CGO_CFLAGS="-Wall -I/opt/tongsuo/include -Wno-deprecated-declarations" CGO_LDFLAGS="-L/opt/tongsuo/lib" go build + build-and-test: + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + runs-on: ${{matrix.os}} + steps: + - uses: actions/checkout@v3 - - name: Test - run: LD_LIBRARY_PATH=/opt/tongsuo/lib CGO_CFLAGS="-Wall -I/opt/tongsuo/include -Wno-deprecated-declarations" CGO_LDFLAGS="-L/opt/tongsuo/lib" go test ./... + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: 1.19 - build_static: - runs-on: ubuntu-22.04 + - name: Clone Tongsuo + uses: actions/checkout@v3 + with: + repository: Tongsuo-Project/Tongsuo + path: Tongsuo + ref: 8.3-stable + + - name: Build Tongsuo + run: | + cd Tongsuo + ./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls + make -j4 + make install + + - name: Build + run: CGO_CFLAGS="-Wall -I${RUNNER_TEMP}/tongsuo/include -Wno-deprecated-declarations" CGO_LDFLAGS="-L${RUNNER_TEMP}/tongsuo/lib" go build + + - name: Test on Ubuntu + run: LD_LIBRARY_PATH=${RUNNER_TEMP}/tongsuo/lib CGO_CFLAGS="-Wall -I${RUNNER_TEMP}/tongsuo/include -Wno-deprecated-declarations" CGO_LDFLAGS="-L${RUNNER_TEMP}/tongsuo/lib" go test ./... + if: matrix.os == 'ubuntu-latest' + + - name: Test on macOS + run: DYLD_LIBRARY_PATH=${RUNNER_TEMP}/tongsuo/lib CGO_CFLAGS="-Wall -I${RUNNER_TEMP}/tongsuo/include -Wno-deprecated-declarations" CGO_LDFLAGS="-L${RUNNER_TEMP}/tongsuo/lib" go test ./... + if: matrix.os == 'macos-latest' + + build-static: + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + runs-on: ${{matrix.os}} steps: - uses: actions/checkout@v3 @@ -68,9 +108,77 @@ jobs: - name: Build Tongsuo Static run: | cd tongsuo - ./config --prefix=/opt/tongsuo --libdir=/opt/tongsuo/lib enable-ntls no-shared + ./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls no-shared make -j4 make install - - name: Test - run: LD_LIBRARY_PATH=/opt/tongsuo/lib CGO_CFLAGS="-Wall -I/opt/tongsuo/include -Wno-deprecated-declarations" CGO_LDFLAGS="-L/opt/tongsuo/lib" go test ./... + - name: Build + run: CGO_CFLAGS="-Wall -I${RUNNER_TEMP}/tongsuo/include -Wno-deprecated-declarations" CGO_LDFLAGS="-L${RUNNER_TEMP}/tongsuo/lib" go build + + - name: Test on Ubuntu + run: LD_LIBRARY_PATH=${RUNNER_TEMP}/tongsuo/lib CGO_CFLAGS="-Wall -I${RUNNER_TEMP}/tongsuo/include -Wno-deprecated-declarations" CGO_LDFLAGS="-L${RUNNER_TEMP}/tongsuo/lib" go test ./... + if: matrix.os == 'ubuntu-latest' + + - name: Test on macOS + run: DYLD_LIBRARY_PATH=${RUNNER_TEMP}/tongsuo/lib CGO_CFLAGS="-Wall -I${RUNNER_TEMP}/tongsuo/include -Wno-deprecated-declarations" CGO_LDFLAGS="-L${RUNNER_TEMP}/tongsuo/lib" go test ./... + if: matrix.os == 'macos-latest' + + build-on-windows: + runs-on: windows-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: 1.19 + + - name: Clone Tongsuo + uses: actions/checkout@v3 + with: + repository: Tongsuo-Project/Tongsuo + path: Tongsuo + ref: 8.3-stable + - uses: ilammy/msvc-dev-cmd@v1 + - uses: ilammy/setup-nasm@v1 + - uses: shogo82148/actions-setup-perl@v1 + - name: Build Tongsuo + shell: cmd + run: | + mkdir _build + cd _build + perl ..\Configure VC-WIN64A no-makedepend --prefix=%RUNNER_TEMP%\tongsuo enable-ntls + nmake /S + nmake install + working-directory: Tongsuo + + - name: Build + shell: cmd + run: | + set CGO_CFLAGS=-Wall -I%RUNNER_TEMP%\tongsuo\include -Wno-deprecated-declarations + set CGO_LDFLAGS=-L%RUNNER_TEMP%\tongsuo\lib" + go build + + - name: Set PATH for go test runtime library search + shell: perl {0} + run: | + use Actions::Core; + add_path("$ENV{RUNNER_TEMP}\\tongsuo\\bin"); + add_path("$ENV{RUNNER_TEMP}\\tongsuo\\lib"); + + - name: Test on Windows + shell: cmd + run: | + copy /y %RUNNER_TEMP%\tongsuo\bin\*.dll "C:\Program Files\MySQL\MySQL Server 8.0\bin" + copy /y %RUNNER_TEMP%\tongsuo\bin\*.dll "C:\Program Files\OpenSSL\bin" + copy /y %RUNNER_TEMP%\tongsuo\bin\*.dll C:\Windows\system32 + copy /y %RUNNER_TEMP%\tongsuo\bin\*.dll C:\Strawberry\c\bin + copy /y %RUNNER_TEMP%\tongsuo\bin\*.dll "C:\Program Files\Microsoft Service Fabric\bin\Fabric\Fabric.Code" + copy /y %RUNNER_TEMP%\tongsuo\bin\*.dll "C:\Program Files\Git\mingw64\bin" + copy /y %RUNNER_TEMP%\tongsuo\bin\*.dll c:\tools\php + copy /y %RUNNER_TEMP%\tongsuo\bin\*.dll "C:\Program Files\Amazon\AWSCLIV2" + set CGO_CFLAGS=-Wall -I%RUNNER_TEMP%\tongsuo\include -Wno-deprecated-declarations + set CGO_LDFLAGS=-L%RUNNER_TEMP%\tongsuo\lib + go env + echo %PATH% + go test ./... diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..6ed1f27 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,89 @@ +linters: + enable-all: true + disable: + - ireturn + - gochecknoinits + - exhaustruct + - nlreturn + +linters-settings: + cyclop: + max-complexity: 20 + interfacebloat: + max: 11 + lll: + line-length: 120 + funlen: + lines: 120 + statements: 80 + ignore-comments: true + depguard: + rules: + main: + allow: + - $gostd + - github.com/tongsuo-project/tongsuo-go-sdk + - github.com/tongsuo-project/tongsuo-go-sdk/crypto + - github.com/tongsuo-project/tongsuo-go-sdk/utils + +issues: + exclude: + - "variable name '(i|e|n|wg|md|ok|ca|bn|iv|ip|rv|rc|fn)' is too short for the scope of its usage" + - "parameter name '(e|r|s|ok|in|ip|iv|fn|rv)' is too short for the scope of its usage" + exclude-rules: + - path: crypto/sha1/sha1_test.go + linters: + - gosec + - path: crypto/md5/md5_test.go + linters: + - gosec + - path: conn.go + text: "Error return value of `c.flushOutputBuffer` is not checked" + - path: utils/errors.go + text: "do not define dynamic errors, use wrapped static errors instead:" + - path: ntls_test.go + text: "Error return value of `server.(Run|RunForALPN)` is not checked" + - path: ssl_test.go + text: "G402: TLS InsecureSkipVerify set true." + - path: crypto/key_test.go + text: "G101: Potential hardcoded credentials: (RSA|SSH \\(EC\\)) private key" + - path: ssl_test.go + text: "G101: Potential hardcoded credentials: (RSA|SSH \\(EC\\)) private key" + - path: ssl_test.go + text: "G402: TLS MinVersion too low." + - path: ctx.go + text: "Consider pre-allocating `protoList`" + - path: crypto/ciphers_gcm.go + text: "Magic number: (128|192|256), in detected" + - path: .*\.go + text: "dupSubExpr: suspicious identical LHS and RHS for `==` operator" + - path: crypto/sm2/sm2.go + text: "return with no blank line before" + - path: crypto/bio.go + text: "return with no blank line before" + - path: crypto/bio.go + text: "(readBioMapping|writeBioMapping) is a global variable" + - path: crypto/key_test.go + text: "Function '(TestMarshal|TestMarshalEC)' has too many statements" + - path: ctx.go + text: "sslCtxIdx is a global variable" + - path: ssl.go + text: "sslIdx is a global variable" + - path: .*_test\.go + text: "cognitive complexity (.*) of func `(TestMarshalEC|TestMarshal|TestSessionReuse|TestNTLS)` is high" + - path: .*_test\.go + text: "cyclomatic complexity (.*) of func `(TestMarshalEC|TestMarshal)` is high" + - path: .*_test\.go + text: "calculated cyclomatic complexity for function (TestMarshal|TestMarshalEC) is (.*), max is (.*)" + - path: .*_test\.go + text: "error returned from external package is unwrapped" + - path: crypto/key.go + text: "`if curve == SM2Curve` has complex nested blocks \\(complexity: 6\\)" + - path: crypto/init.go + text: "do not define dynamic errors, use wrapped static errors instead:" + - path: http.go + text: "http.go:(.*): Line contains TODO/BUG/FIXME: \"TODO: http client integration\"" + - path: ctx.go + text: "errorf: should replace errors.New" + - path: ctx.go + text: "do not define dynamic errors, use wrapped static errors instead:" diff --git a/build.go b/build.go index b3497e2..3cd9d8a 100644 --- a/build.go +++ b/build.go @@ -17,8 +17,8 @@ package tongsuogo -// #cgo linux LDFLAGS: -lssl -lcrypto -// #cgo darwin LDFLAGS: -lssl -lcrypto +// #cgo linux LDFLAGS: -lssl +// #cgo darwin LDFLAGS: -lssl // #cgo windows CFLAGS: -DWIN32_LEAN_AND_MEAN -// #cgo windows pkg-config: libssl libcrypto +// #cgo windows LDFLAGS: -lssl import "C" diff --git a/build_static.go b/build_static.go index fc97d5a..d76139b 100644 --- a/build_static.go +++ b/build_static.go @@ -17,8 +17,8 @@ package tongsuogo -// #cgo linux LDFLAGS: -extldflags -static -lssl -lcrypto -// #cgo darwin LDFLAGS: -extldflags -static -lssl -lcrypto +// #cgo linux LDFLAGS: -extldflags -static -lssl +// #cgo darwin LDFLAGS: -extldflags -static -lssl // #cgo windows CFLAGS: -DWIN32_LEAN_AND_MEAN -// #cgo windows pkg-config: libssl libcrypto +// #cgo windows LDFLAGS: -extldflags -static -lssl import "C" diff --git a/conn.go b/conn.go index e03afd3..7c979db 100644 --- a/conn.go +++ b/conn.go @@ -32,22 +32,22 @@ import ( ) var ( - zeroReturn = errors.New("zero return") - wantRead = errors.New("want read") - wantWrite = errors.New("want write") - tryAgain = errors.New("try again") + errZeroReturn = errors.New("zero return") + errWantRead = errors.New("want read") + errWantWrite = errors.New("want write") + errTryAgain = errors.New("try again") ) type Conn struct { *SSL - conn net.Conn - ctx *Ctx // for gc - into_ssl *crypto.ReadBio - from_ssl *crypto.WriteBio - is_shutdown bool - mtx sync.Mutex - want_read_future *utils.Future + conn net.Conn + ctx *Ctx // for gc + intoSSL *crypto.ReadBio + fromSSL *crypto.WriteBio + isShutdown bool + mtx sync.Mutex + wantReadFuture *utils.Future } type VerifyResult int @@ -105,8 +105,9 @@ func newSSL(ctx *C.SSL_CTX) (*C.SSL, error) { defer runtime.UnlockOSThread() ssl := C.SSL_new(ctx) if ssl == nil { - return nil, crypto.ErrorFromErrorQueue() + return nil, fmt.Errorf("failed to create SSL: %w", crypto.PopError()) } + return ssl, nil } @@ -116,43 +117,44 @@ func newConn(conn net.Conn, ctx *Ctx) (*Conn, error) { return nil, err } - into_ssl := &crypto.ReadBio{} - from_ssl := &crypto.WriteBio{} + intoSSL := &crypto.ReadBio{} + fromSSL := &crypto.WriteBio{} if ctx.GetMode()&ReleaseBuffers > 0 { - into_ssl.SetRelease(true) - from_ssl.SetRelease(true) + intoSSL.SetRelease(true) + fromSSL.SetRelease(true) } - into_ssl_cbio := into_ssl.MakeCBIO() - from_ssl_cbio := from_ssl.MakeCBIO() - if into_ssl_cbio == nil || from_ssl_cbio == nil { + intoSSLCbio := intoSSL.MakeCBIO() + fromSSLCbio := fromSSL.MakeCBIO() + if intoSSLCbio == nil || fromSSLCbio == nil { // these frees are null safe - C.BIO_free((*C.BIO)(into_ssl_cbio)) - C.BIO_free((*C.BIO)(from_ssl_cbio)) + C.BIO_free((*C.BIO)(intoSSLCbio)) + C.BIO_free((*C.BIO)(fromSSLCbio)) C.SSL_free(ssl) - return nil, errors.New("failed to allocate memory BIO") + return nil, fmt.Errorf("failed to allocate memory BIO: %w", crypto.ErrMallocFailure) } // the ssl object takes ownership of these objects now - C.SSL_set_bio(ssl, (*C.BIO)(into_ssl_cbio), (*C.BIO)(from_ssl_cbio)) + C.SSL_set_bio(ssl, (*C.BIO)(intoSSLCbio), (*C.BIO)(fromSSLCbio)) s := &SSL{ssl: ssl} C.SSL_set_ex_data(s.ssl, get_ssl_idx(), unsafe.Pointer(s.ssl)) - c := &Conn{ - SSL: s, - - conn: conn, - ctx: ctx, - into_ssl: into_ssl, - from_ssl: from_ssl} - runtime.SetFinalizer(c, func(c *Conn) { - c.into_ssl.Disconnect(into_ssl_cbio) - c.from_ssl.Disconnect(from_ssl_cbio) + con := &Conn{ + SSL: s, + conn: conn, + ctx: ctx, + intoSSL: intoSSL, + fromSSL: fromSSL, + } + runtime.SetFinalizer(con, func(c *Conn) { + c.intoSSL.Disconnect(intoSSLCbio) + c.fromSSL.Disconnect(fromSSLCbio) C.SSL_free(c.ssl) }) - return c, nil + + return con, nil } // Client wraps an existing stream connection and puts it in the connect state @@ -193,7 +195,7 @@ func (c *Conn) GetCtx() *Ctx { return c.ctx } func (c *Conn) CurrentCipher() (string, error) { p := C.X_SSL_get_cipher_name(c.ssl) if p == nil { - return "", errors.New("Session not established") + return "", fmt.Errorf("failed to get cipher: %w", crypto.ErrNoCipher) } return C.GoString(p), nil @@ -202,7 +204,7 @@ func (c *Conn) CurrentCipher() (string, error) { func (c *Conn) GetVersion() (string, error) { p := C.X_SSL_get_version(c.ssl) if p == nil { - return "", errors.New("Failed to get version") + return "", fmt.Errorf("failed to get version: %w", crypto.ErrNoVersion) } return C.GoString(p), nil @@ -210,21 +212,31 @@ func (c *Conn) GetVersion() (string, error) { func (c *Conn) fillInputBuffer() error { for { - n, err := c.into_ssl.ReadFromOnce(c.conn) + n, err := c.intoSSL.ReadFromOnce(c.conn) if n == 0 && err == nil { continue } - if err == io.EOF { - c.into_ssl.MarkEOF() + + if errors.Is(err, io.EOF) { + c.intoSSL.MarkEOF() return c.Close() } - return err + + if err != nil { + return fmt.Errorf("failed to read from connection: %w", err) + } + + return nil } } func (c *Conn) flushOutputBuffer() error { - _, err := c.from_ssl.WriteTo(c.conn) - return err + _, err := c.fromSSL.WriteTo(c.conn) + if err != nil { + return fmt.Errorf("failed to write to connection: %w", err) + } + + return nil } func (c *Conn) getErrorHandler(rv C.int, errno error) func() error { @@ -237,27 +249,35 @@ func (c *Conn) getErrorHandler(rv C.int, errno error) func() error { } case C.SSL_ERROR_WANT_READ: go c.flushOutputBuffer() - if c.want_read_future != nil { - want_read_future := c.want_read_future + if c.wantReadFuture != nil { + wantReadFuture := c.wantReadFuture return func() error { - _, err := want_read_future.Get() - return err + _, err := wantReadFuture.Get() + if err != nil { + return fmt.Errorf("want read future get error: %w", err) + } + return nil } } - c.want_read_future = utils.NewFuture() - want_read_future := c.want_read_future - return func() (err error) { + c.wantReadFuture = utils.NewFuture() + wantReadFuture := c.wantReadFuture + return func() error { + var err error + defer func() { c.mtx.Lock() - c.want_read_future = nil + c.wantReadFuture = nil c.mtx.Unlock() - want_read_future.Set(nil, err) + wantReadFuture.Set(nil, err) }() + err = c.fillInputBuffer() if err != nil { return err } - return tryAgain + + err = errTryAgain + return err } case C.SSL_ERROR_WANT_WRITE: return func() error { @@ -265,26 +285,26 @@ func (c *Conn) getErrorHandler(rv C.int, errno error) func() error { if err != nil { return err } - return tryAgain + return errTryAgain } case C.SSL_ERROR_SYSCALL: var err error if C.ERR_peek_error() == 0 { switch rv { case 0: - err = errors.New("protocol-violating EOF") + err = fmt.Errorf("protocol-violating: %w", crypto.ErrUnexpectedEOF) case -1: err = errno default: - err = crypto.ErrorFromErrorQueue() + err = crypto.PopError() } } else { - err = crypto.ErrorFromErrorQueue() + err = crypto.PopError() } - return func() error { return err } + return func() error { return fmt.Errorf("syscall error: %w", err) } default: - err := crypto.ErrorFromErrorQueue() - return func() error { return err } + err := crypto.PopError() + return func() error { return fmt.Errorf("SSL error: %w", err) } } } @@ -298,23 +318,28 @@ func (c *Conn) handleError(errcb func() error) error { func (c *Conn) handshake() func() error { c.mtx.Lock() defer c.mtx.Unlock() - if c.is_shutdown { + + if c.isShutdown { return func() error { return io.ErrUnexpectedEOF } } + runtime.LockOSThread() defer runtime.UnlockOSThread() + rv, errno := C.SSL_do_handshake(c.ssl) if rv > 0 { return nil } + return c.getErrorHandler(rv, errno) } // Handshake performs an SSL handshake. If a handshake is not manually // triggered, it will run before the first I/O on the encrypted stream. func (c *Conn) Handshake() error { - err := tryAgain - for err == tryAgain { + err := errTryAgain + + for errors.Is(err, errTryAgain) { err = c.handleError(c.handshake()) } go c.flushOutputBuffer() @@ -326,12 +351,12 @@ func (c *Conn) Handshake() error { func (c *Conn) PeerCertificate() (*crypto.Certificate, error) { c.mtx.Lock() defer c.mtx.Unlock() - if c.is_shutdown { - return nil, errors.New("connection closed") + if c.isShutdown { + return nil, fmt.Errorf("connection closed: %w", crypto.ErrShutdown) } x := C.SSL_get_peer_certificate(c.ssl) if x == nil { - return nil, errors.New("no peer certificate found") + return nil, fmt.Errorf("failed to get peer cert: %w", crypto.ErrNoPeerCert) } cert := crypto.NewCertWrapper(unsafe.Pointer(x)) runtime.SetFinalizer(cert, func(cert *crypto.Certificate) { @@ -342,12 +367,10 @@ func (c *Conn) PeerCertificate() (*crypto.Certificate, error) { // loadCertificateStack loads up a stack of x509 certificates and returns them, // handling memory ownership. -func (c *Conn) loadCertificateStack(sk *C.struct_stack_st_X509) ( - rv []*crypto.Certificate) { - - sk_num := int(C.X_sk_X509_num(sk)) - rv = make([]*crypto.Certificate, 0, sk_num) - for i := 0; i < sk_num; i++ { +func (c *Conn) loadCertificateStack(sk *C.struct_stack_st_X509) []*crypto.Certificate { + skNum := int(C.X_sk_X509_num(sk)) + rv := make([]*crypto.Certificate, 0, skNum) + for i := 0; i < skNum; i++ { x := C.X_sk_X509_value(sk, C.int(i)) // ref holds on to the underlying connection memory so we don't need to // worry about incrementing refcounts manually or freeing the X509 @@ -360,15 +383,15 @@ func (c *Conn) loadCertificateStack(sk *C.struct_stack_st_X509) ( // the client side, the stack also contains the peer's certificate; if called // on the server side, the peer's certificate must be obtained separately using // PeerCertificate. -func (c *Conn) PeerCertificateChain() (rv []*crypto.Certificate, err error) { +func (c *Conn) PeerCertificateChain() ([]*crypto.Certificate, error) { c.mtx.Lock() defer c.mtx.Unlock() - if c.is_shutdown { - return nil, errors.New("connection closed") + if c.isShutdown { + return nil, fmt.Errorf("connection closed: %w", crypto.ErrShutdown) } sk := C.SSL_get_peer_cert_chain(c.ssl) if sk == nil { - return nil, errors.New("no peer certificates found") + return nil, fmt.Errorf("no peer certificates found: %w", crypto.ErrNoPeerCert) } return c.loadCertificateStack(sk), nil } @@ -381,11 +404,15 @@ type ConnectionState struct { SessionReused bool } -func (c *Conn) ConnectionState() (rv ConnectionState) { - rv.Certificate, rv.CertificateError = c.PeerCertificate() - rv.CertificateChain, rv.CertificateChainError = c.PeerCertificateChain() - rv.SessionReused = c.SessionReused() - return +func (c *Conn) ConnectionState() ConnectionState { + cert, certErr := c.PeerCertificate() + certChain, certChainErr := c.PeerCertificateChain() + sessReused := c.SessionReused() + + return ConnectionState{ + Certificate: cert, CertificateError: certErr, CertificateChain: certChain, + CertificateChainError: certChainErr, SessionReused: sessReused, + } } func (c *Conn) shutdown() func() error { @@ -408,27 +435,31 @@ func (c *Conn) shutdown() func() error { // without tickling them to close by sending a TCP_FIN packet, or // shutting down the write-side of the connection. return nil - } else { - return c.getErrorHandler(rv, errno) } + + return c.getErrorHandler(rv, errno) } func (c *Conn) shutdownLoop() error { - err := tryAgain - shutdown_tries := 0 - for err == tryAgain { - shutdown_tries = shutdown_tries + 1 + err := errTryAgain + shutdownTries := 0 + + for errors.Is(err, errTryAgain) { + shutdownTries++ err = c.handleError(c.shutdown()) if err == nil { return c.flushOutputBuffer() } - if err == tryAgain && shutdown_tries >= 2 { - return errors.New("shutdown requested a third time?") + + if errors.Is(err, errTryAgain) && shutdownTries >= 2 { + return fmt.Errorf("shutdown requested a third time? %w", crypto.ErrShutdown) } } - if err == io.ErrUnexpectedEOF { + + if errors.Is(err, io.ErrUnexpectedEOF) { err = nil } + return err } @@ -436,92 +467,105 @@ func (c *Conn) shutdownLoop() error { // connection. func (c *Conn) Close() error { c.mtx.Lock() - if c.is_shutdown { + if c.isShutdown { c.mtx.Unlock() return nil } - c.is_shutdown = true + c.isShutdown = true c.mtx.Unlock() var errs utils.ErrorGroup errs.Add(c.shutdownLoop()) errs.Add(c.conn.Close()) - return errs.Finalize() + + err := errs.Finalize() + if err != nil { + return fmt.Errorf("shutdown or close error: %w", err) + } + + return nil } -func (c *Conn) read(b []byte) (int, func() error) { - if len(b) == 0 { +func (c *Conn) read(buf []byte) (int, func() error) { + if len(buf) == 0 { return 0, nil } c.mtx.Lock() defer c.mtx.Unlock() - if c.is_shutdown { + if c.isShutdown { return 0, func() error { return io.EOF } } runtime.LockOSThread() defer runtime.UnlockOSThread() - rv, errno := C.SSL_read(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b))) + rv, errno := C.SSL_read(c.ssl, unsafe.Pointer(&buf[0]), C.int(len(buf))) if rv > 0 { return int(rv), nil } return 0, c.getErrorHandler(rv, errno) } -// Read reads up to len(b) bytes into b. It returns the number of bytes read +// Read reads up to len(buf) bytes into buf. It returns the number of bytes read // and an error if applicable. io.EOF is returned when the caller can expect // to see no more data. -func (c *Conn) Read(b []byte) (n int, err error) { - if len(b) == 0 { +func (c *Conn) Read(buf []byte) (int, error) { + if len(buf) == 0 { return 0, nil } - err = tryAgain - for err == tryAgain { - n, errcb := c.read(b) + err := errTryAgain + + for errors.Is(err, errTryAgain) { + n, errcb := c.read(buf) err = c.handleError(errcb) if err == nil { go c.flushOutputBuffer() return n, nil } - if err == io.ErrUnexpectedEOF { + + if errors.Is(err, io.ErrUnexpectedEOF) { err = io.EOF } } return 0, err } -func (c *Conn) write(b []byte) (int, func() error) { - if len(b) == 0 { +func (c *Conn) write(buf []byte) (int, func() error) { + if len(buf) == 0 { return 0, nil } c.mtx.Lock() defer c.mtx.Unlock() - if c.is_shutdown { - err := errors.New("connection closed") + if c.isShutdown { + err := fmt.Errorf("connection closed: %w", crypto.ErrShutdown) return 0, func() error { return err } } runtime.LockOSThread() defer runtime.UnlockOSThread() - rv, errno := C.SSL_write(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b))) + + rv, errno := C.SSL_write(c.ssl, unsafe.Pointer(&buf[0]), C.int(len(buf))) if rv > 0 { return int(rv), nil } + return 0, c.getErrorHandler(rv, errno) } // Write will encrypt the contents of b and write it to the underlying stream. // Performance will be vastly improved if the size of b is a multiple of // SSLRecordSize. -func (c *Conn) Write(b []byte) (written int, err error) { - if len(b) == 0 { +func (c *Conn) Write(data []byte) (int, error) { + if len(data) == 0 { return 0, nil } - err = tryAgain - for err == tryAgain { - n, errcb := c.write(b) + + err := errTryAgain + + for errors.Is(err, errTryAgain) { + n, errcb := c.write(data) err = c.handleError(errcb) if err == nil { return n, c.flushOutputBuffer() } } + return 0, err } @@ -532,7 +576,13 @@ func (c *Conn) VerifyHostname(host string) error { if err != nil { return err } - return cert.VerifyHostname(host) + + err = cert.VerifyHostname(host) + if err != nil { + return fmt.Errorf("failed to verify hostname: %w", err) + } + + return nil } // LocalAddr returns the underlying connection's local address @@ -547,30 +597,45 @@ func (c *Conn) RemoteAddr() net.Addr { // SetDeadline calls SetDeadline on the underlying connection. func (c *Conn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) + err := c.conn.SetDeadline(t) + if err != nil { + return fmt.Errorf("failed to set deadline: %w", err) + } + + return nil } // SetReadDeadline calls SetReadDeadline on the underlying connection. func (c *Conn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) + err := c.conn.SetReadDeadline(t) + if err != nil { + return fmt.Errorf("failed to set read deadline: %w", err) + } + + return nil } // SetWriteDeadline calls SetWriteDeadline on the underlying connection. func (c *Conn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) + err := c.conn.SetWriteDeadline(t) + if err != nil { + return fmt.Errorf("failed to set write deadline: %w", err) + } + + return nil } func (c *Conn) UnderlyingConn() net.Conn { return c.conn } -func (c *Conn) SetTlsExtHostName(name string) error { +func (c *Conn) SetTLSExtHostName(name string) error { cname := C.CString(name) defer C.free(unsafe.Pointer(cname)) runtime.LockOSThread() defer runtime.UnlockOSThread() if C.X_SSL_set_tlsext_host_name(c.ssl, cname) == 0 { - return crypto.ErrorFromErrorQueue() + return fmt.Errorf("failed to set TLS host name: %w", crypto.PopError()) } return nil } @@ -588,9 +653,9 @@ func (c *Conn) GetSession() ([]byte, error) { defer runtime.UnlockOSThread() // get1 increases the refcount of the session, so we have to free it. - session := (*C.SSL_SESSION)(C.SSL_get1_session(c.ssl)) + session := C.SSL_get1_session(c.ssl) if session == nil { - return nil, errors.New("failed to get session") + return nil, fmt.Errorf("failed to get session: %w", crypto.ErrNoSession) } defer C.SSL_SESSION_free(session) @@ -605,7 +670,7 @@ func (c *Conn) GetSession() ([]byte, error) { tmp := buf slen2 := C.i2d_SSL_SESSION(session, &tmp) if slen != slen2 { - return nil, errors.New("session had different lengths") + return nil, fmt.Errorf("session had different lengths: %w", crypto.ErrSessionLength) } return C.GoBytes(unsafe.Pointer(buf), slen), nil @@ -616,22 +681,22 @@ func (c *Conn) setSession(session []byte) error { defer runtime.UnlockOSThread() if len(session) == 0 { - return fmt.Errorf("session is empty") + return fmt.Errorf("session is empty: %w", crypto.ErrEmptySession) } cSession := C.CBytes(session) - defer C.free(unsafe.Pointer(cSession)) + defer C.free(cSession) ptr := (*C.uchar)(cSession) - s := C.d2i_SSL_SESSION(nil, (**C.uchar)(&ptr), C.long(len(session))) - if s == nil { - return fmt.Errorf("unable to load session: %s", crypto.ErrorFromErrorQueue()) + sess := C.d2i_SSL_SESSION(nil, &ptr, C.long(len(session))) + if sess == nil { + return fmt.Errorf("unable to load session: %w", crypto.PopError()) } - defer C.SSL_SESSION_free(s) + defer C.SSL_SESSION_free(sess) - ret := C.SSL_set_session(c.ssl, s) + ret := C.SSL_set_session(c.ssl, sess) if ret != 1 { - return fmt.Errorf("unable to set session: %s", crypto.ErrorFromErrorQueue()) + return fmt.Errorf("unable to set session: %w", crypto.PopError()) } return nil } @@ -642,7 +707,7 @@ func (c *Conn) GetALPNNegotiated() (string, error) { var protoLen C.uint C.SSL_get0_alpn_selected(c.ssl, &proto, &protoLen) if protoLen == 0 { - return "", fmt.Errorf("no ALPN protocol negotiated") + return "", fmt.Errorf("no ALPN protocol negotiated: %w", crypto.ErrNoALPN) } return C.GoStringN((*C.char)(unsafe.Pointer(proto)), C.int(protoLen)), nil } diff --git a/crypto/bio.go b/crypto/bio.go index 7323830..25eecb4 100644 --- a/crypto/bio.go +++ b/crypto/bio.go @@ -18,9 +18,8 @@ package crypto import "C" import ( - "errors" + "fmt" "io" - "reflect" "sync" "unsafe" ) @@ -30,12 +29,7 @@ const ( ) func nonCopyGoBytes(ptr uintptr, length int) []byte { - var slice []byte - header := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) - header.Cap = length - header.Len = length - header.Data = ptr - return slice + return unsafe.Slice((*byte)(unsafe.Pointer(ptr)), length) } func nonCopyCString(data *C.char, size C.int) []byte { @@ -45,14 +39,15 @@ func nonCopyCString(data *C.char, size C.int) []byte { var writeBioMapping = newMapping() type WriteBio struct { - data_mtx sync.Mutex - op_mtx sync.Mutex - buf []byte - release_buffers bool + dataMtx sync.Mutex + opMtx sync.Mutex + buf []byte + releaseBuffers bool } func loadWritePtr(b *C.BIO) *WriteBio { t := token(C.X_BIO_get_data(b)) + return (*WriteBio)(writeBioMapping.Get(t)) } @@ -65,41 +60,50 @@ func bioSetRetryRead(b *C.BIO) { } //export go_write_bio_write -func go_write_bio_write(b *C.BIO, data *C.char, size C.int) (rc C.int) { +func go_write_bio_write(bio *C.BIO, data *C.char, size C.int) C.int { + var rc C.int + defer func() { if err := recover(); err != nil { - //logger.Critf("openssl: writeBioWrite panic'd: %v", err) + // logger.Critf("openssl: writeBioWrite panic'd: %v", err) rc = -1 } }() - ptr := loadWritePtr(b) + ptr := loadWritePtr(bio) if ptr == nil || data == nil || size < 0 { return -1 } - ptr.data_mtx.Lock() - defer ptr.data_mtx.Unlock() - bioClearRetryFlags(b) + ptr.dataMtx.Lock() + defer ptr.dataMtx.Unlock() + bioClearRetryFlags(bio) ptr.buf = append(ptr.buf, nonCopyCString(data, size)...) - return size + rc = size + + return rc } //export go_write_bio_ctrl -func go_write_bio_ctrl(b *C.BIO, cmd C.int, arg1 C.long, arg2 unsafe.Pointer) ( - rc C.long) { +func go_write_bio_ctrl(bio *C.BIO, cmd C.int, arg1 C.long, arg2 unsafe.Pointer) C.long { + _, _ = arg1, arg2 // unused + + var rc C.long + defer func() { if err := recover(); err != nil { - //logger.Critf("openssl: writeBioCtrl panic'd: %v", err) + // logger.Critf("openssl: writeBioCtrl panic'd: %v", err) rc = -1 } }() switch cmd { case C.BIO_CTRL_WPENDING: - return writeBioPending(b) + rc = writeBioPending(bio) case C.BIO_CTRL_DUP, C.BIO_CTRL_FLUSH: - return 1 + rc = 1 default: - return 0 + rc = 0 } + + return rc } func writeBioPending(b *C.BIO) C.long { @@ -107,64 +111,66 @@ func writeBioPending(b *C.BIO) C.long { if ptr == nil { return 0 } - ptr.data_mtx.Lock() - defer ptr.data_mtx.Unlock() + ptr.dataMtx.Lock() + defer ptr.dataMtx.Unlock() + return C.long(len(ptr.buf)) } -func (b *WriteBio) WriteTo(w io.Writer) (rv int64, err error) { - b.op_mtx.Lock() - defer b.op_mtx.Unlock() +func (bio *WriteBio) WriteTo(writer io.Writer) (int64, error) { + bio.opMtx.Lock() + defer bio.opMtx.Unlock() // write whatever data we currently have - b.data_mtx.Lock() - data := b.buf - b.data_mtx.Unlock() + bio.dataMtx.Lock() + data := bio.buf + bio.dataMtx.Unlock() if len(data) == 0 { return 0, nil } - n, err := w.Write(data) + n, err := writer.Write(data) // subtract however much data we wrote from the buffer - b.data_mtx.Lock() - b.buf = b.buf[:copy(b.buf, b.buf[n:])] - if b.release_buffers && len(b.buf) == 0 { - b.buf = nil + bio.dataMtx.Lock() + bio.buf = bio.buf[:copy(bio.buf, bio.buf[n:])] + if bio.releaseBuffers && len(bio.buf) == 0 { + bio.buf = nil } - b.data_mtx.Unlock() + bio.dataMtx.Unlock() return int64(n), err } -func (b *WriteBio) SetRelease(flag bool) { - b.data_mtx.Lock() - defer b.data_mtx.Unlock() - b.release_buffers = flag +func (bio *WriteBio) SetRelease(flag bool) { + bio.dataMtx.Lock() + defer bio.dataMtx.Unlock() + bio.releaseBuffers = flag } -func (self *WriteBio) Disconnect(b *C.BIO) { - if loadWritePtr(b) == self { +func (bio *WriteBio) Disconnect(b *C.BIO) { + if loadWritePtr(b) == bio { writeBioMapping.Del(token(C.X_BIO_get_data(b))) C.X_BIO_set_data(b, nil) } } -func (b *WriteBio) MakeCBIO() *C.BIO { +func (bio *WriteBio) MakeCBIO() *C.BIO { rv := C.X_BIO_new_write_bio() - token := writeBioMapping.Add(unsafe.Pointer(b)) + token := writeBioMapping.Add(unsafe.Pointer(bio)) C.X_BIO_set_data(rv, unsafe.Pointer(token)) + return rv } var readBioMapping = newMapping() type ReadBio struct { - data_mtx sync.Mutex - op_mtx sync.Mutex - buf []byte - eof bool - release_buffers bool + dataMtx sync.Mutex + opMtx sync.Mutex + buf []byte + eof bool + releaseBuffers bool } func loadReadPtr(b *C.BIO) *ReadBio { @@ -172,56 +178,61 @@ func loadReadPtr(b *C.BIO) *ReadBio { } //export go_read_bio_read -func go_read_bio_read(b *C.BIO, data *C.char, size C.int) (rc C.int) { +func go_read_bio_read(bio *C.BIO, data *C.char, size C.int) C.int { + rc := 0 + defer func() { if err := recover(); err != nil { - //logger.Critf("openssl: go_read_bio_read panic'd: %v", err) + // logger.Critf("openssl: go_read_bio_read panic'd: %v", err) rc = -1 } }() - ptr := loadReadPtr(b) + ptr := loadReadPtr(bio) if ptr == nil || size < 0 { return -1 } - ptr.data_mtx.Lock() - defer ptr.data_mtx.Unlock() - bioClearRetryFlags(b) + ptr.dataMtx.Lock() + defer ptr.dataMtx.Unlock() + bioClearRetryFlags(bio) if len(ptr.buf) == 0 { if ptr.eof { return 0 } - bioSetRetryRead(b) + bioSetRetryRead(bio) return -1 } if size == 0 || data == nil { return C.int(len(ptr.buf)) } - n := copy(nonCopyCString(data, size), ptr.buf) - ptr.buf = ptr.buf[:copy(ptr.buf, ptr.buf[n:])] - if ptr.release_buffers && len(ptr.buf) == 0 { + rc = copy(nonCopyCString(data, size), ptr.buf) + ptr.buf = ptr.buf[:copy(ptr.buf, ptr.buf[rc:])] + if ptr.releaseBuffers && len(ptr.buf) == 0 { ptr.buf = nil } - return C.int(n) + return C.int(rc) } //export go_read_bio_ctrl -func go_read_bio_ctrl(b *C.BIO, cmd C.int, arg1 C.long, arg2 unsafe.Pointer) ( - rc C.long) { +func go_read_bio_ctrl(bio *C.BIO, cmd C.int, arg1 C.long, arg2 unsafe.Pointer) C.long { + _, _ = arg1, arg2 // unused + var rc C.long defer func() { if err := recover(); err != nil { - //logger.Critf("openssl: readBioCtrl panic'd: %v", err) + // logger.Critf("openssl: readBioCtrl panic'd: %v", err) rc = -1 } }() switch cmd { case C.BIO_CTRL_PENDING: - return readBioPending(b) + rc = readBioPending(bio) case C.BIO_CTRL_DUP, C.BIO_CTRL_FLUSH: - return 1 + rc = 1 default: - return 0 + rc = 0 } + + return rc } func readBioPending(b *C.BIO) C.long { @@ -229,89 +240,98 @@ func readBioPending(b *C.BIO) C.long { if ptr == nil { return 0 } - ptr.data_mtx.Lock() - defer ptr.data_mtx.Unlock() + ptr.dataMtx.Lock() + defer ptr.dataMtx.Unlock() return C.long(len(ptr.buf)) } -func (b *ReadBio) SetRelease(flag bool) { - b.data_mtx.Lock() - defer b.data_mtx.Unlock() - b.release_buffers = flag +func (bio *ReadBio) SetRelease(flag bool) { + bio.dataMtx.Lock() + defer bio.dataMtx.Unlock() + bio.releaseBuffers = flag } -func (b *ReadBio) ReadFromOnce(r io.Reader) (n int, err error) { - b.op_mtx.Lock() - defer b.op_mtx.Unlock() +func (bio *ReadBio) ReadFromOnce(r io.Reader) (int, error) { + bio.opMtx.Lock() + defer bio.opMtx.Unlock() // make sure we have a destination that fits at least one SSL record - b.data_mtx.Lock() - if cap(b.buf) < len(b.buf)+SSLRecordSize { - new_buf := make([]byte, len(b.buf), len(b.buf)+SSLRecordSize) - copy(new_buf, b.buf) - b.buf = new_buf + bio.dataMtx.Lock() + if cap(bio.buf) < len(bio.buf)+SSLRecordSize { + newBuf := make([]byte, len(bio.buf), len(bio.buf)+SSLRecordSize) + copy(newBuf, bio.buf) + bio.buf = newBuf } - dst := b.buf[len(b.buf):cap(b.buf)] - dst_slice := b.buf - b.data_mtx.Unlock() - n, err = r.Read(dst) - b.data_mtx.Lock() - defer b.data_mtx.Unlock() + dst := bio.buf[len(bio.buf):cap(bio.buf)] + dstSlice := bio.buf + bio.dataMtx.Unlock() + + n, err := r.Read(dst) + bio.dataMtx.Lock() + defer bio.dataMtx.Unlock() if n > 0 { - if len(dst_slice) != len(b.buf) { + if len(dstSlice) != len(bio.buf) { // someone shrunk the buffer, so we read in too far ahead and we // need to slide backwards - copy(b.buf[len(b.buf):len(b.buf)+n], dst) + copy(bio.buf[len(bio.buf):len(bio.buf)+n], dst) } - b.buf = b.buf[:len(b.buf)+n] + bio.buf = bio.buf[:len(bio.buf)+n] + } + + if err != nil { + return n, fmt.Errorf("read from once error: %w", err) } - return n, err + + return n, nil } -func (b *ReadBio) MakeCBIO() *C.BIO { +func (bio *ReadBio) MakeCBIO() *C.BIO { rv := C.X_BIO_new_read_bio() - token := readBioMapping.Add(unsafe.Pointer(b)) + token := readBioMapping.Add(unsafe.Pointer(bio)) C.X_BIO_set_data(rv, unsafe.Pointer(token)) return rv } -func (self *ReadBio) Disconnect(b *C.BIO) { - if loadReadPtr(b) == self { +func (bio *ReadBio) Disconnect(b *C.BIO) { + if loadReadPtr(b) == bio { readBioMapping.Del(token(C.X_BIO_get_data(b))) C.X_BIO_set_data(b, nil) } } -func (b *ReadBio) MarkEOF() { - b.data_mtx.Lock() - defer b.data_mtx.Unlock() - b.eof = true +func (bio *ReadBio) MarkEOF() { + bio.dataMtx.Lock() + defer bio.dataMtx.Unlock() + bio.eof = true } type anyBio C.BIO func asAnyBio(b *C.BIO) *anyBio { return (*anyBio)(b) } -func (b *anyBio) Read(buf []byte) (n int, err error) { +func (bio *anyBio) Read(buf []byte) (int, error) { if len(buf) == 0 { return 0, nil } - n = int(C.X_BIO_read((*C.BIO)(b), unsafe.Pointer(&buf[0]), C.int(len(buf)))) + n := int(C.X_BIO_read((*C.BIO)(bio), unsafe.Pointer(&buf[0]), C.int(len(buf)))) if n <= 0 { return 0, io.EOF } return n, nil } -func (b *anyBio) Write(buf []byte) (written int, err error) { +func (bio *anyBio) Write(buf []byte) (int, error) { if len(buf) == 0 { return 0, nil } - n := int(C.X_BIO_write((*C.BIO)(b), unsafe.Pointer(&buf[0]), + ret := int(C.X_BIO_write((*C.BIO)(bio), unsafe.Pointer(&buf[0]), C.int(len(buf)))) - if n != len(buf) { - return n, errors.New("BIO write failed") + if ret < 0 { + return 0, fmt.Errorf("BIO write failed: %w", PopError()) } - return n, nil + if ret < len(buf) { + return ret, fmt.Errorf("BIO write trucated: %w", ErrPartialWrite) + } + return ret, nil } diff --git a/crypto/build.go b/crypto/build.go index 24972ec..aa37997 100644 --- a/crypto/build.go +++ b/crypto/build.go @@ -13,5 +13,5 @@ package crypto // #cgo linux LDFLAGS: -lcrypto // #cgo darwin LDFLAGS: -lcrypto // #cgo windows CFLAGS: -DWIN32_LEAN_AND_MEAN -// #cgo windows pkg-config: libcrypto +// #cgo windows LDFLAGS: -lcrypto import "C" diff --git a/crypto/build_static.go b/crypto/build_static.go index 7433e6a..cb3957b 100644 --- a/crypto/build_static.go +++ b/crypto/build_static.go @@ -13,5 +13,5 @@ package crypto // #cgo linux LDFLAGS: -extldflags -static -lcrypto // #cgo darwin LDFLAGS: -lcrypto // #cgo windows CFLAGS: -DWIN32_LEAN_AND_MEAN -// #cgo windows pkg-config: libcrypto +// #cgo windows LDFLAGS: -extldflags -static -lcrypto import "C" diff --git a/crypto/cert.go b/crypto/cert.go index 2ed1063..b49df89 100644 --- a/crypto/cert.go +++ b/crypto/cert.go @@ -18,10 +18,8 @@ package crypto import "C" import ( - "errors" "fmt" "io" - "io/ioutil" "math/big" "os" "runtime" @@ -29,23 +27,23 @@ import ( "unsafe" ) -type EVP_MD int +type MDAlgo int const ( - EVP_NULL EVP_MD = iota - EVP_MD5 EVP_MD = iota - EVP_MD4 EVP_MD = iota - EVP_SHA EVP_MD = iota - EVP_SHA1 EVP_MD = iota - EVP_DSS EVP_MD = iota - EVP_DSS1 EVP_MD = iota - EVP_MDC2 EVP_MD = iota - EVP_RIPEMD160 EVP_MD = iota - EVP_SHA224 EVP_MD = iota - EVP_SHA256 EVP_MD = iota - EVP_SHA384 EVP_MD = iota - EVP_SHA512 EVP_MD = iota - EVP_SM3 EVP_MD = iota + MDNull MDAlgo = iota + MDMD5 MDAlgo = iota + MDMD4 MDAlgo = iota + MDSHA MDAlgo = iota + MDSHA1 MDAlgo = iota + MDDSS MDAlgo = iota + MDDSS1 MDAlgo = iota + MDMDC2 MDAlgo = iota + MDRipemd160 MDAlgo = iota + MDSHA224 MDAlgo = iota + MDSHA256 MDAlgo = iota + MDSHA384 MDAlgo = iota + MDSHA512 MDAlgo = iota + MDSM3 MDAlgo = iota ) type GMDoubleCertKey struct { @@ -55,14 +53,14 @@ type GMDoubleCertKey struct { EncKeyFile string } -// X509_Version represents a version on a x509 certificate. -type X509_Version int +// X509Version represents a version on a x509 certificate. +type X509Version int // Specify constants for x509 versions because the standard states that they // are represented internally as one lower than the common version name. const ( - X509_V1 X509_Version = 0 - X509_V3 X509_Version = 2 + X509V1 X509Version = 0 + X509V3 X509Version = 2 ) type Certificate struct { @@ -88,35 +86,40 @@ type Name struct { func NewCertWrapper(x unsafe.Pointer, ref ...interface{}) *Certificate { if len(ref) > 0 { return &Certificate{x: (*C.X509)(x), ref: ref[0]} - } else { - return &Certificate{x: (*C.X509)(x)} } + + return &Certificate{x: (*C.X509)(x)} } // NewName allocate and return a new Name object. func NewName() (*Name, error) { n := C.X509_NAME_new() if n == nil { - return nil, errors.New("could not create x509 name") + return nil, fmt.Errorf("could not create x509 name: %w", ErrMallocFailure) } name := &Name{name: n} runtime.SetFinalizer(name, func(n *Name) { C.X509_NAME_free(n.name) }) + return name, nil } // AddTextEntry appends a text entry to an X509 NAME. func (n *Name) AddTextEntry(field, value string) error { cfield := C.CString(field) + defer C.free(unsafe.Pointer(cfield)) + cvalue := (*C.uchar)(unsafe.Pointer(C.CString(value))) + defer C.free(unsafe.Pointer(cvalue)) - ret := C.X509_NAME_add_entry_by_txt( - n.name, cfield, C.MBSTRING_ASC, cvalue, -1, -1, 0) + + ret := C.X509_NAME_add_entry_by_txt(n.name, cfield, C.MBSTRING_ASC, cvalue, -1, -1, 0) if ret != 1 { - return errors.New("failed to add x509 name text entry") + return fmt.Errorf("failed to add x509 name text entry: %w", PopError()) } + return nil } @@ -132,7 +135,7 @@ func (n *Name) AddTextEntries(entries map[string]string) error { // GetEntry returns a name entry based on NID. If no entry, then ("", false) is // returned. -func (n *Name) GetEntry(nid NID) (entry string, ok bool) { +func (n *Name) GetEntry(nid NID) (string, bool) { entrylen := C.X509_NAME_get_text_by_NID(n.name, C.int(nid), nil, 0) if entrylen == -1 { return "", false @@ -146,14 +149,14 @@ func (n *Name) GetEntry(nid NID) (entry string, ok bool) { // NewCertificate generates a basic certificate based // on the provided CertificateInfo struct func NewCertificate(info *CertificateInfo, key PublicKey) (*Certificate, error) { - c := &Certificate{x: C.X509_new()} - runtime.SetFinalizer(c, func(c *Certificate) { + cert := &Certificate{x: C.X509_new()} + runtime.SetFinalizer(cert, func(c *Certificate) { C.X509_free(c.x) }) - if err := c.SetVersion(X509_V3); err != nil { + if err := cert.SetVersion(X509V3); err != nil { return nil, err } - name, err := c.GetSubjectName() + name, err := cert.GetSubjectName() if err != nil { return nil, err } @@ -166,22 +169,22 @@ func NewCertificate(info *CertificateInfo, key PublicKey) (*Certificate, error) return nil, err } // self-issue for now - if err := c.SetIssuerName(name); err != nil { + if err := cert.SetIssuerName(name); err != nil { return nil, err } - if err := c.SetSerial(info.Serial); err != nil { + if err := cert.SetSerial(info.Serial); err != nil { return nil, err } - if err := c.SetIssueDate(info.Issued); err != nil { + if err := cert.SetIssueDate(info.Issued); err != nil { return nil, err } - if err := c.SetExpireDate(info.Expires); err != nil { + if err := cert.SetExpireDate(info.Expires); err != nil { return nil, err } - if err := c.SetPubKey(key); err != nil { + if err := cert.SetPubKey(key); err != nil { return nil, err } - return c, nil + return cert, nil } func (c *Certificate) GetCert() *C.X509 { @@ -191,7 +194,7 @@ func (c *Certificate) GetCert() *C.X509 { func (c *Certificate) GetSubjectName() (*Name, error) { n := C.X509_get_subject_name(c.x) if n == nil { - return nil, errors.New("failed to get subject name") + return nil, fmt.Errorf("failed to get subject name: %w", ErrNilParameter) } return &Name{name: n}, nil } @@ -199,14 +202,14 @@ func (c *Certificate) GetSubjectName() (*Name, error) { func (c *Certificate) GetIssuerName() (*Name, error) { n := C.X509_get_issuer_name(c.x) if n == nil { - return nil, errors.New("failed to get issuer name") + return nil, fmt.Errorf("failed to get issuer name: %w", ErrNilParameter) } return &Name{name: n}, nil } func (c *Certificate) SetSubjectName(name *Name) error { if C.X509_set_subject_name(c.x, name.name) != 1 { - return errors.New("failed to set subject name") + return fmt.Errorf("failed to set subject name: %w", PopError()) } return nil } @@ -230,7 +233,7 @@ func (c *Certificate) SetIssuer(issuer *Certificate) error { // Use SetIssuer instead, if possible. func (c *Certificate) SetIssuerName(name *Name) error { if C.X509_set_issuer_name(c.x, name.name) != 1 { - return errors.New("failed to set subject name") + return fmt.Errorf("failed to set subject name: %w", PopError()) } return nil } @@ -244,13 +247,13 @@ func (c *Certificate) SetSerial(serial *big.Int) error { serialBytes := serial.Bytes() if bn = C.BN_bin2bn((*C.uchar)(unsafe.Pointer(&serialBytes[0])), C.int(len(serialBytes)), bn); bn == nil { - return errors.New("failed to set serial") + return fmt.Errorf("failed to set serial: %w", PopError()) } if sno = C.BN_to_ASN1_INTEGER(bn, sno); sno == nil { - return errors.New("failed to set serial") + return fmt.Errorf("failed to set serial: %w", PopError()) } if C.X509_set_serialNumber(c.x, sno) != 1 { - return errors.New("failed to set serial") + return fmt.Errorf("failed to set serial: %w", PopError()) } return nil } @@ -260,7 +263,7 @@ func (c *Certificate) SetIssueDate(when time.Duration) error { offset := C.long(when / time.Second) result := C.X509_gmtime_adj(C.X_X509_get0_notBefore(c.x), offset) if result == nil { - return errors.New("failed to set issue date") + return fmt.Errorf("failed to set issue date: %w", PopError()) } return nil } @@ -270,7 +273,7 @@ func (c *Certificate) SetExpireDate(when time.Duration) error { offset := C.long(when / time.Second) result := C.X509_gmtime_adj(C.X_X509_get0_notAfter(c.x), offset) if result == nil { - return errors.New("failed to set expire date") + return fmt.Errorf("failed to set expire date: %w", PopError()) } return nil } @@ -279,30 +282,29 @@ func (c *Certificate) SetExpireDate(when time.Duration) error { func (c *Certificate) SetPubKey(pubKey PublicKey) error { c.pubKey = pubKey if C.X509_set_pubkey(c.x, pubKey.EvpPKey()) != 1 { - return errors.New("failed to set public key") + return fmt.Errorf("failed to set public key: %w", PopError()) } return nil } // Sign a certificate using a private key and a digest name. // Accepted digest names are 'sm3', 'sha256', 'sha384', and 'sha512'. -func (c *Certificate) Sign(privKey PrivateKey, digest EVP_MD) error { +func (c *Certificate) Sign(privKey PrivateKey, digest MDAlgo) error { switch digest { - case EVP_SM3: - case EVP_SHA256: - case EVP_SHA384: - case EVP_SHA512: + case MDSM3: + case MDSHA256: + case MDSHA384: + case MDSHA512: default: - return errors.New("Unsupported digest" + - "You're probably looking for 'EVP_SHA256' or 'EVP_SHA512'.") + return ErrUnsupportedDigest } return c.insecureSign(privKey, digest) } -func (c *Certificate) insecureSign(privKey PrivateKey, digest EVP_MD) error { +func (c *Certificate) insecureSign(privKey PrivateKey, digest MDAlgo) error { var md *C.EVP_MD = getDigestFunction(digest) if C.X509_sign(c.x, privKey.EvpPKey(), md) <= 0 { - return errors.New("failed to sign certificate") + return fmt.Errorf("failed to sign certificate: %w", PopError()) } return nil } @@ -311,13 +313,13 @@ func (c *Certificate) insecureSign(privKey PrivateKey, digest EVP_MD) error { // Extension constants are NID_* as found in openssl. func (c *Certificate) AddExtension(nid NID, value string) error { if c.x == nil { - return errors.New("certificate is nil") + return fmt.Errorf("certificate is nil: %w", ErrNilParameter) } issuer := c if c.Issuer != nil { if c.Issuer.x == nil { - return errors.New("issuer certificate is nil") + return fmt.Errorf("issuer certificate is nil: %w", ErrNilParameter) } issuer = c.Issuer } @@ -328,33 +330,26 @@ func (c *Certificate) AddExtension(nid NID, value string) error { var ctx C.X509V3_CTX C.X509V3_set_ctx(&ctx, c.x, issuer.x, nil, nil, 0) - ex := C.X509V3_EXT_conf_nid(nil, &ctx, C.int(nid), cValue) - if ex == nil { - return fmt.Errorf("failed to create x509v3 extension: %s", getOpenSSLError()) + ext := C.X509V3_EXT_conf_nid(nil, &ctx, C.int(nid), cValue) + if ext == nil { + return fmt.Errorf("failed to create x509v3 extension: %w", PopError()) } - defer C.X509_EXTENSION_free(ex) + defer C.X509_EXTENSION_free(ext) - if C.X509_add_ext(c.x, ex, -1) <= 0 { - return fmt.Errorf("failed to add x509v3 extension: %s", getOpenSSLError()) + if C.X509_add_ext(c.x, ext, -1) <= 0 { + return fmt.Errorf("failed to add x509v3 extension: %w", PopError()) } return nil } -// getOpenSSLError Get the last error from the OpenSSL error queue. -func getOpenSSLError() string { - var errStrBuf [120]byte - C.ERR_error_string_n(C.ERR_get_error(), (*C.char)(unsafe.Pointer(&errStrBuf[0])), 120) - return string(errStrBuf[:]) -} - // helper function to validate extension input func validateExtensionInput(nid NID, value string) error { if nid <= 0 { - return errors.New("invalid NID") + return ErrInvalidNid } if value == "" { - return errors.New("empty extension value") + return ErrEmptyExtensionValue } return nil } @@ -362,11 +357,12 @@ func validateExtensionInput(nid NID, value string) error { // AddExtensions Wraps AddExtension using a map of NID to text extension. // Will return without finishing if it encounters an error. func (c *Certificate) AddExtensions(extensions map[NID]string) error { - targetNid := NID_authority_key_identifier + targetNid := NidAuthorityKeyIdentifier found := false for nid, value := range extensions { - if nid == NID_authority_key_identifier { + if nid == NidAuthorityKeyIdentifier { found = true + continue } if err := c.AddExtension(nid, value); err != nil { @@ -384,18 +380,18 @@ func (c *Certificate) AddExtensions(extensions map[NID]string) error { } // LoadCertificateFromPEM loads an X509 certificate from a PEM-encoded block. -func LoadCertificateFromPEM(pem_block []byte) (*Certificate, error) { - if len(pem_block) == 0 { - return nil, errors.New("empty pem block") +func LoadCertificateFromPEM(pemBlock []byte) (*Certificate, error) { + if len(pemBlock) == 0 { + return nil, ErrNoCert } runtime.LockOSThread() defer runtime.UnlockOSThread() - bio := C.BIO_new_mem_buf(unsafe.Pointer(&pem_block[0]), - C.int(len(pem_block))) + + bio := C.BIO_new_mem_buf(unsafe.Pointer(&pemBlock[0]), C.int(len(pemBlock))) cert := C.PEM_read_bio_X509(bio, nil, nil, nil) C.BIO_free(bio) if cert == nil { - return nil, ErrorFromErrorQueue() + return nil, PopError() } x := &Certificate{x: cert} runtime.SetFinalizer(x, func(x *Certificate) { @@ -405,23 +401,30 @@ func LoadCertificateFromPEM(pem_block []byte) (*Certificate, error) { } // MarshalPEM converts the X509 certificate to PEM-encoded format -func (c *Certificate) MarshalPEM() (pem_block []byte, err error) { +func (c *Certificate) MarshalPEM() ([]byte, error) { bio := C.BIO_new(C.BIO_s_mem()) if bio == nil { - return nil, errors.New("failed to allocate memory BIO") + return nil, ErrMallocFailure } + defer C.BIO_free(bio) if int(C.PEM_write_bio_X509(bio, c.x)) != 1 { - return nil, errors.New("failed dumping certificate") + return nil, fmt.Errorf("failed to write certificate: %w", PopError()) } - return ioutil.ReadAll(asAnyBio(bio)) + + data, err := io.ReadAll(asAnyBio(bio)) + if err != nil { + return nil, fmt.Errorf("failed to read certificate: %w", err) + } + + return data, nil } // PublicKey returns the public key embedded in the X509 certificate. func (c *Certificate) PublicKey() (PublicKey, error) { pkey := C.X509_get_pubkey(c.x) if pkey == nil { - return nil, errors.New("no public key found") + return nil, ErrNoPubKey } key := &pKey{key: pkey} runtime.SetFinalizer(key, func(key *pKey) { @@ -431,55 +434,57 @@ func (c *Certificate) PublicKey() (PublicKey, error) { } // GetSerialNumberHex returns the certificate's serial number in hex format -func (c *Certificate) GetSerialNumberHex() (serial string) { - asn1_i := C.X509_get_serialNumber(c.x) - bignum := C.ASN1_INTEGER_to_BN(asn1_i, nil) +func (c *Certificate) GetSerialNumberHex() string { + asn1Num := C.X509_get_serialNumber(c.x) + bignum := C.ASN1_INTEGER_to_BN(asn1Num, nil) + defer C.BN_free(bignum) + hex := C.BN_bn2hex(bignum) - serial = C.GoString(hex) - C.BN_free(bignum) - C.X_OPENSSL_free(unsafe.Pointer(hex)) - return + defer C.X_OPENSSL_free(unsafe.Pointer(hex)) + + serial := C.GoString(hex) + + return serial } // GetVersion returns the X509 version of the certificate. -func (c *Certificate) GetVersion() X509_Version { - return X509_Version(C.X_X509_get_version(c.x)) +func (c *Certificate) GetVersion() X509Version { + return X509Version(C.X_X509_get_version(c.x)) } // SetVersion sets the X509 version of the certificate. -func (c *Certificate) SetVersion(version X509_Version) error { +func (c *Certificate) SetVersion(version X509Version) error { cvers := C.long(version) if C.X_X509_set_version(c.x, cvers) != 1 { - return errors.New("failed to set certificate version") + return fmt.Errorf("failed to set certificate version: %w", PopError()) } return nil } -func getDigestFunction(digest EVP_MD) (md *C.EVP_MD) { +func getDigestFunction(digest MDAlgo) *C.EVP_MD { + var md *C.EVP_MD switch digest { - // please don't use these digest functions - case EVP_NULL: + case MDNull: md = C.X_EVP_md_null() - case EVP_MD5: + case MDMD5: md = C.X_EVP_md5() - case EVP_SHA: + case MDSHA: md = C.X_EVP_sha() - case EVP_SHA1: + case MDSHA1: md = C.X_EVP_sha1() - case EVP_DSS: + case MDDSS: md = C.X_EVP_dss() - case EVP_DSS1: + case MDDSS1: md = C.X_EVP_dss1() - case EVP_SHA224: + case MDSHA224: md = C.X_EVP_sha224() - // you actually want one of these - case EVP_SHA256: + case MDSHA256: md = C.X_EVP_sha256() - case EVP_SHA384: + case MDSHA384: md = C.X_EVP_sha384() - case EVP_SHA512: + case MDSHA512: md = C.X_EVP_sha512() - case EVP_SM3: + case MDSM3: md = C.X_EVP_sm3() } return md @@ -489,13 +494,13 @@ func getDigestFunction(digest EVP_MD) (md *C.EVP_MD) { func LoadPEMFromFile(filename string) ([]byte, error) { file, err := os.Open(filename) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to open file: %w", err) } defer file.Close() pemBlock, err := io.ReadAll(file) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to read file: %w", err) } return pemBlock, nil @@ -505,13 +510,13 @@ func LoadPEMFromFile(filename string) ([]byte, error) { func SavePEMToFile(pemBlock []byte, filename string) error { file, err := os.Create(filename) if err != nil { - return err + return fmt.Errorf("failed to create file: %w", err) } defer file.Close() _, err = file.Write(pemBlock) if err != nil { - return err + return fmt.Errorf("failed to write: %w", err) } return nil diff --git a/crypto/cert_test.go b/crypto/cert_test.go index fdae813..5595ec1 100644 --- a/crypto/cert_test.go +++ b/crypto/cert_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package crypto +package crypto_test import ( "math/big" @@ -20,14 +20,19 @@ import ( "path/filepath" "testing" "time" + + "github.com/tongsuo-project/tongsuo-go-sdk/crypto" ) func TestCertGenerate(t *testing.T) { - key, err := GenerateRSAKey(768) + t.Parallel() + + key, err := crypto.GenerateRSAKey(768) if err != nil { t.Fatal(err) } - info := &CertificateInfo{ + + info := &crypto.CertificateInfo{ Serial: big.NewInt(int64(1)), Issued: 0, Expires: 24 * time.Hour, @@ -35,21 +40,26 @@ func TestCertGenerate(t *testing.T) { Organization: "Test", CommonName: "localhost", } - cert, err := NewCertificate(info, key) + + cert, err := crypto.NewCertificate(info, key) if err != nil { t.Fatal(err) } - if err := cert.Sign(key, EVP_SHA256); err != nil { + + if err := cert.Sign(key, crypto.MDSHA256); err != nil { t.Fatal(err) } } func TestCertGenerateSM2(t *testing.T) { - key, err := GenerateECKey(Sm2Curve) + t.Parallel() + + key, err := crypto.GenerateECKey(crypto.SM2Curve) if err != nil { t.Fatal(err) } - info := &CertificateInfo{ + + info := &crypto.CertificateInfo{ Serial: big.NewInt(int64(1)), Issued: 0, Expires: 24 * time.Hour, @@ -57,21 +67,26 @@ func TestCertGenerateSM2(t *testing.T) { Organization: "Test", CommonName: "localhost", } - cert, err := NewCertificate(info, key) + + cert, err := crypto.NewCertificate(info, key) if err != nil { t.Fatal(err) } - if err := cert.Sign(key, EVP_SM3); err != nil { + + if err := cert.Sign(key, crypto.MDSM3); err != nil { t.Fatal(err) } } func TestCAGenerate(t *testing.T) { - cakey, err := GenerateRSAKey(768) + t.Parallel() + + cakey, err := crypto.GenerateRSAKey(768) if err != nil { t.Fatal(err) } - info := &CertificateInfo{ + + info := &crypto.CertificateInfo{ Serial: big.NewInt(int64(1)), Issued: 0, Expires: 24 * time.Hour, @@ -79,26 +94,31 @@ func TestCAGenerate(t *testing.T) { Organization: "Test CA", CommonName: "CA", } - ca, err := NewCertificate(info, cakey) + + ca, err := crypto.NewCertificate(info, cakey) if err != nil { t.Fatal(err) } - if err := ca.AddExtensions(map[NID]string{ - NID_basic_constraints: "critical,CA:TRUE", - NID_key_usage: "critical,keyCertSign,cRLSign", - NID_subject_key_identifier: "hash", - NID_netscape_cert_type: "sslCA", + + if err := ca.AddExtensions(map[crypto.NID]string{ + crypto.NidBasicConstraints: "critical,CA:TRUE", + crypto.NidKeyUsage: "critical,keyCertSign,cRLSign", + crypto.NidSubjectKeyIdentifier: "hash", + crypto.NidNetscapeCertType: "sslCA", }); err != nil { t.Fatal(err) } - if err := ca.Sign(cakey, EVP_SHA256); err != nil { + + if err := ca.Sign(cakey, crypto.MDSHA256); err != nil { t.Fatal(err) } - key, err := GenerateRSAKey(768) + + key, err := crypto.GenerateRSAKey(768) if err != nil { t.Fatal(err) } - info = &CertificateInfo{ + + info = &crypto.CertificateInfo{ Serial: big.NewInt(int64(1)), Issued: 0, Expires: 24 * time.Hour, @@ -106,78 +126,89 @@ func TestCAGenerate(t *testing.T) { Organization: "Test", CommonName: "localhost", } - cert, err := NewCertificate(info, key) + + cert, err := crypto.NewCertificate(info, key) if err != nil { t.Fatal(err) } - if err := cert.AddExtensions(map[NID]string{ - NID_basic_constraints: "critical,CA:FALSE", - NID_key_usage: "keyEncipherment", - NID_ext_key_usage: "serverAuth", + + if err := cert.AddExtensions(map[crypto.NID]string{ + crypto.NidBasicConstraints: "critical,CA:FALSE", + crypto.NidKeyUsage: "keyEncipherment", + crypto.NidExtKeyUsage: "serverAuth", }); err != nil { t.Fatal(err) } + if err := cert.SetIssuer(ca); err != nil { t.Fatal(err) } - if err := cert.Sign(cakey, EVP_SHA256); err != nil { + + if err := cert.Sign(cakey, crypto.MDSHA256); err != nil { t.Fatal(err) } } +func generateSM2KeyAndSave(t *testing.T, filename string) crypto.PrivateKey { + t.Helper() + + key, err := crypto.GenerateECKey(crypto.SM2Curve) + if err != nil { + t.Fatal(err) + } + + pem, err := key.MarshalPKCS8PrivateKeyPEM() + if err != nil { + t.Fatal(err) + } + + err = crypto.SavePEMToFile(pem, filename) + if err != nil { + t.Fatal(err) + } + + return key +} + func TestCAGenerateSM2(t *testing.T) { + t.Parallel() + dirName := filepath.Join("test-runs", "TestCAGenerateSM2") _, err := os.Stat(dirName) + if os.IsNotExist(err) { - // The directory does not exist, creating it now. - err := os.MkdirAll(dirName, 0755) + err := os.MkdirAll(dirName, 0o755) if err != nil { t.Logf("Failed to create the directory: %v\n", err) } } else if err != nil { - // other error t.Logf("Failed to check the directory: %v\n", err) } - // Helper function: generate and save key - generateAndSaveKey := func(filename string) PrivateKey { - key, err := GenerateECKey(Sm2Curve) - if err != nil { - t.Fatal(err) - } - pem, err := key.MarshalPKCS8PrivateKeyPEM() + signAndSaveCert := func(cert *crypto.Certificate, caKey crypto.PrivateKey, filename string) { + err := cert.Sign(caKey, crypto.MDSM3) if err != nil { t.Fatal(err) } - err = SavePEMToFile(pem, filename) - if err != nil { - t.Fatal(err) - } - return key - } - // Helper function: sign and save certificate - signAndSaveCert := func(cert *Certificate, caKey PrivateKey, filename string) { - err := cert.Sign(caKey, EVP_SM3) - if err != nil { - t.Fatal(err) - } certPem, err := cert.MarshalPEM() if err != nil { t.Fatal(err) } - err = SavePEMToFile(certPem, filename) + + err = crypto.SavePEMToFile(certPem, filename) if err != nil { t.Fatal(err) } } // Create CA certificate - caKey, err := GenerateECKey(Sm2Curve) + caKey, err := crypto.GenerateECKey(crypto.SM2Curve) if err != nil { t.Fatal(err) } - caInfo := CertificateInfo{ + + caInfo := crypto.CertificateInfo{ big.NewInt(1), 0, 87600 * time.Hour, // 10 years @@ -185,24 +216,26 @@ func TestCAGenerateSM2(t *testing.T) { "Test CA", "CA", } - caExtensions := map[NID]string{ - NID_basic_constraints: "critical,CA:TRUE", - NID_key_usage: "critical,digitalSignature,keyCertSign,cRLSign", - NID_subject_key_identifier: "hash", - NID_authority_key_identifier: "keyid:always,issuer", + caExtensions := map[crypto.NID]string{ + crypto.NidBasicConstraints: "critical,CA:TRUE", + crypto.NidKeyUsage: "critical,digitalSignature,keyCertSign,cRLSign", + crypto.NidSubjectKeyIdentifier: "hash", + crypto.NidAuthorityKeyIdentifier: "keyid:always,issuer", } - ca, err := NewCertificate(&caInfo, caKey) + + ca, err := crypto.NewCertificate(&caInfo, caKey) if err != nil { t.Fatal(err) } + err = ca.AddExtensions(caExtensions) if err != nil { t.Fatal(err) } + caFile := filepath.Join(dirName, "chain-ca.crt") signAndSaveCert(ca, caKey, caFile) - // Define additional certificate information certInfos := []struct { name string keyUsage string @@ -213,11 +246,10 @@ func TestCAGenerateSM2(t *testing.T) { {"client_enc", "keyAgreement, keyEncipherment, dataEncipherment"}, } - // Create additional certificates for _, info := range certInfos { keyFile := filepath.Join(dirName, info.name+".key") - key := generateAndSaveKey(keyFile) - certInfo := CertificateInfo{ + key := generateSM2KeyAndSave(t, keyFile) + certInfo := crypto.CertificateInfo{ Serial: big.NewInt(1), Issued: 0, Expires: 87600 * time.Hour, // 10 years @@ -225,33 +257,40 @@ func TestCAGenerateSM2(t *testing.T) { Organization: "Test", CommonName: "localhost", } - extensions := map[NID]string{ - NID_basic_constraints: "critical,CA:FALSE", - NID_key_usage: info.keyUsage, + extensions := map[crypto.NID]string{ + crypto.NidBasicConstraints: "critical,CA:FALSE", + crypto.NidKeyUsage: info.keyUsage, } - cert, err := NewCertificate(&certInfo, key) + + cert, err := crypto.NewCertificate(&certInfo, key) if err != nil { t.Fatal(err) } + err = cert.AddExtensions(extensions) if err != nil { t.Fatal(err) } + err = cert.SetIssuer(ca) if err != nil { t.Fatal(err) } + certFile := filepath.Join(dirName, info.name+".crt") signAndSaveCert(cert, caKey, certFile) } } func TestCertGetNameEntry(t *testing.T) { - key, err := GenerateRSAKey(768) + t.Parallel() + + key, err := crypto.GenerateRSAKey(768) if err != nil { t.Fatal(err) } - info := &CertificateInfo{ + + info := &crypto.CertificateInfo{ Serial: big.NewInt(int64(1)), Issued: 0, Expires: 24 * time.Hour, @@ -259,36 +298,45 @@ func TestCertGetNameEntry(t *testing.T) { Organization: "Test", CommonName: "localhost", } - cert, err := NewCertificate(info, key) + + cert, err := crypto.NewCertificate(info, key) if err != nil { t.Fatal(err) } + name, err := cert.GetSubjectName() if err != nil { t.Fatal(err) } - entry, ok := name.GetEntry(NID_commonName) + + entry, ok := name.GetEntry(crypto.NidCommonName) if !ok { t.Fatal("no common name") } + if entry != "localhost" { t.Fatalf("expected localhost; got %q", entry) } - entry, ok = name.GetEntry(NID_localityName) + + entry, ok = name.GetEntry(crypto.NidLocalityName) if ok { t.Fatal("did not expect a locality name") } + if entry != "" { t.Fatalf("entry should be empty; got %q", entry) } } func TestCertVersion(t *testing.T) { - key, err := GenerateRSAKey(768) + t.Parallel() + + key, err := crypto.GenerateRSAKey(768) if err != nil { t.Fatal(err) } - info := &CertificateInfo{ + + info := &crypto.CertificateInfo{ Serial: big.NewInt(int64(1)), Issued: 0, Expires: 24 * time.Hour, @@ -296,14 +344,17 @@ func TestCertVersion(t *testing.T) { Organization: "Test", CommonName: "localhost", } - cert, err := NewCertificate(info, key) + + cert, err := crypto.NewCertificate(info, key) if err != nil { t.Fatal(err) } - if err := cert.SetVersion(X509_V3); err != nil { + + if err := cert.SetVersion(crypto.X509V3); err != nil { t.Fatal(err) } - if vers := cert.GetVersion(); vers != X509_V3 { + + if vers := cert.GetVersion(); vers != crypto.X509V3 { t.Fatalf("bad version: %d", vers) } } diff --git a/crypto/ciphers.go b/crypto/ciphers.go index 7f75c66..ff15af3 100644 --- a/crypto/ciphers.go +++ b/crypto/ciphers.go @@ -18,24 +18,23 @@ package crypto import "C" import ( - "errors" "fmt" "runtime" "unsafe" ) const ( - GCM_TAG_MAXLEN = 16 + GCMTagMaxLen = 16 ) const ( - CIPHER_MODE_ECB = 1 - CIPHER_MODE_CBC = 2 - CIPHER_MODE_CFB = 3 - CIPHER_MODE_OFB = 4 - CIPHER_MODE_CTR = 5 - CIPHER_MODE_GCM = 6 - CIPHER_MODE_CCM = 7 + CipherModeECB = 1 + CipherModeCBC = 2 + CipherModeCFB = 3 + CipherModeOFB = 4 + CipherModeCTR = 5 + CipherModeGCM = 6 + CipherModeCCM = 7 ) type CipherCtx interface { @@ -83,7 +82,7 @@ func (c *Cipher) IVSize() int { func Nid2ShortName(nid NID) (string, error) { sn := C.OBJ_nid2sn(C.int(nid)) if sn == nil { - return "", fmt.Errorf("NID %d not found", nid) + return "", PopError() } return C.GoString(sn), nil } @@ -93,7 +92,7 @@ func GetCipherByName(name string) (*Cipher, error) { defer C.free(unsafe.Pointer(cname)) p := C.EVP_get_cipherbyname(cname) if p == nil { - return nil, fmt.Errorf("Cipher %v not found", name) + return nil, ErrCipherNotFound } // we can consider ciphers to use static mem; don't need to free return &Cipher{ptr: p}, nil @@ -114,7 +113,7 @@ type cipherCtx struct { func newCipherCtx() (*cipherCtx, error) { cctx := C.EVP_CIPHER_CTX_new() if cctx == nil { - return nil, errors.New("failed to allocate cipher context") + return nil, ErrMallocFailure } ctx := &cipherCtx{cctx} runtime.SetFinalizer(ctx, func(ctx *cipherCtx) { @@ -127,15 +126,15 @@ func (ctx *cipherCtx) SetKeyAndIV(key, iv []byte) error { var kptr, iptr *C.uchar if key != nil { if len(key) != ctx.KeySize() { - return fmt.Errorf("bad key size (%d bytes instead of %d)", - len(key), ctx.KeySize()) + return fmt.Errorf("bad key size (%d bytes instead of %d): %w", + len(key), ctx.KeySize(), ErrBadKeySize) } kptr = (*C.uchar)(&key[0]) } if iv != nil { if len(iv) != ctx.IVSize() { - return fmt.Errorf("bad IV size (%d bytes instead of %d)", - len(iv), ctx.IVSize()) + return fmt.Errorf("bad IV size (%d bytes instead of %d): %w", + len(iv), ctx.IVSize(), ErrBadIvSize) } iptr = (*C.uchar)(&iv[0]) } @@ -146,8 +145,8 @@ func (ctx *cipherCtx) SetKeyAndIV(key, iv []byte) error { } else { res = C.EVP_DecryptInit_ex(ctx.ctx, nil, nil, kptr, iptr) } - if 1 != res { - return errors.New("failed to apply key/IV") + if res != 1 { + return PopError() } } return nil @@ -184,8 +183,8 @@ func (ctx *cipherCtx) SetPadding(pad bool) { func (ctx *cipherCtx) SetCtrl(code, arg int) error { res := C.EVP_CIPHER_CTX_ctrl(ctx.ctx, C.int(code), C.int(arg), nil) if res != 1 { - return fmt.Errorf("failed to set code %d to %d [result %d]", - code, arg, res) + return fmt.Errorf("failed to set code %d to %d [result %d]: %w", + code, arg, res, PopError()) } return nil } @@ -194,8 +193,8 @@ func (ctx *cipherCtx) SetCtrlBytes(code, arg int, value []byte) error { res := C.EVP_CIPHER_CTX_ctrl(ctx.ctx, C.int(code), C.int(arg), unsafe.Pointer(&value[0])) if res != 1 { - return fmt.Errorf("failed to set code %d with arg %d to %x [result %d]", - code, arg, value, res) + return fmt.Errorf("failed to set code %d with arg %d to %x [result %d]: %w", + code, arg, value, res, PopError()) } return nil } @@ -205,8 +204,8 @@ func (ctx *cipherCtx) GetCtrlInt(code, arg int) (int, error) { res := C.EVP_CIPHER_CTX_ctrl(ctx.ctx, C.int(code), C.int(arg), unsafe.Pointer(&returnVal)) if res != 1 { - return 0, fmt.Errorf("failed to get code %d with arg %d [result %d]", - code, arg, res) + return 0, fmt.Errorf("failed to get code %d with arg %d [result %d]: %w", + code, arg, res, PopError()) } return int(returnVal), nil } @@ -216,8 +215,8 @@ func (ctx *cipherCtx) GetCtrlBytes(code, arg, expectsize int) ([]byte, error) { res := C.EVP_CIPHER_CTX_ctrl(ctx.ctx, C.int(code), C.int(arg), unsafe.Pointer(&returnVal[0])) if res != 1 { - return nil, fmt.Errorf("failed to get code %d with arg %d [result %d]", - code, arg, res) + return nil, fmt.Errorf("failed to get code %d with arg %d [result %d]: %w", + code, arg, res, PopError()) } return returnVal, nil } @@ -255,10 +254,11 @@ type decryptionCipherCtx struct { *cipherCtx } -func newEncryptionCipherCtx(c *Cipher, e *Engine, key, iv []byte) ( - *encryptionCipherCtx, error) { - if c == nil { - return nil, errors.New("null cipher not allowed") +func newEncryptionCipherCtx(cipher *Cipher, e *Engine, key, iv []byte) ( + *encryptionCipherCtx, error, +) { + if cipher == nil { + return nil, ErrNilParameter } ctx, err := newCipherCtx() if err != nil { @@ -266,10 +266,10 @@ func newEncryptionCipherCtx(c *Cipher, e *Engine, key, iv []byte) ( } var eptr *C.ENGINE if e != nil { - eptr = (*C.ENGINE)(e.Engine()) + eptr = e.Engine() } - if 1 != C.EVP_EncryptInit_ex(ctx.ctx, c.ptr, eptr, nil, nil) { - return nil, errors.New("failed to initialize cipher context") + if C.EVP_EncryptInit_ex(ctx.ctx, cipher.ptr, eptr, nil, nil) != 1 { + return nil, PopError() } err = ctx.SetKeyAndIV(key, iv) if err != nil { @@ -278,10 +278,11 @@ func newEncryptionCipherCtx(c *Cipher, e *Engine, key, iv []byte) ( return &encryptionCipherCtx{cipherCtx: ctx}, nil } -func newDecryptionCipherCtx(c *Cipher, e *Engine, key, iv []byte) ( - *decryptionCipherCtx, error) { - if c == nil { - return nil, errors.New("null cipher not allowed") +func newDecryptionCipherCtx(cipher *Cipher, e *Engine, key, iv []byte) ( + *decryptionCipherCtx, error, +) { + if cipher == nil { + return nil, ErrNilParameter } ctx, err := newCipherCtx() if err != nil { @@ -289,10 +290,10 @@ func newDecryptionCipherCtx(c *Cipher, e *Engine, key, iv []byte) ( } var eptr *C.ENGINE if e != nil { - eptr = (*C.ENGINE)(e.Engine()) + eptr = e.Engine() } - if 1 != C.EVP_DecryptInit_ex(ctx.ctx, c.ptr, eptr, nil, nil) { - return nil, errors.New("failed to initialize cipher context") + if C.EVP_DecryptInit_ex(ctx.ctx, cipher.ptr, eptr, nil, nil) != 1 { + return nil, PopError() } err = ctx.SetKeyAndIV(key, iv) if err != nil { @@ -302,12 +303,14 @@ func newDecryptionCipherCtx(c *Cipher, e *Engine, key, iv []byte) ( } func NewEncryptionCipherCtx(c *Cipher, e *Engine, key, iv []byte) ( - EncryptionCipherCtx, error) { + EncryptionCipherCtx, error, +) { return newEncryptionCipherCtx(c, e, key, iv) } func NewDecryptionCipherCtx(c *Cipher, e *Engine, key, iv []byte) ( - DecryptionCipherCtx, error) { + DecryptionCipherCtx, error, +) { return newDecryptionCipherCtx(c, e, key, iv) } @@ -320,7 +323,7 @@ func (ctx *encryptionCipherCtx) EncryptUpdate(input []byte) ([]byte, error) { res := C.EVP_EncryptUpdate(ctx.ctx, (*C.uchar)(&outbuf[0]), &outlen, (*C.uchar)(&input[0]), C.int(len(input))) if res != 1 { - return nil, fmt.Errorf("failed to encrypt [result %d]", res) + return nil, fmt.Errorf("failed to encrypt [result %d]: %w", res, PopError()) } return outbuf[:outlen], nil } @@ -334,7 +337,7 @@ func (ctx *decryptionCipherCtx) DecryptUpdate(input []byte) ([]byte, error) { res := C.EVP_DecryptUpdate(ctx.ctx, (*C.uchar)(&outbuf[0]), &outlen, (*C.uchar)(&input[0]), C.int(len(input))) if res != 1 { - return nil, fmt.Errorf("failed to decrypt [result %d]", res) + return nil, fmt.Errorf("failed to decrypt [result %d]: %w", res, PopError()) } return outbuf[:outlen], nil } @@ -342,8 +345,8 @@ func (ctx *decryptionCipherCtx) DecryptUpdate(input []byte) ([]byte, error) { func (ctx *encryptionCipherCtx) EncryptFinal() ([]byte, error) { outbuf := make([]byte, ctx.BlockSize()) var outlen C.int - if 1 != C.EVP_EncryptFinal_ex(ctx.ctx, (*C.uchar)(&outbuf[0]), &outlen) { - return nil, errors.New("encryption failed") + if C.EVP_EncryptFinal_ex(ctx.ctx, (*C.uchar)(&outbuf[0]), &outlen) != 1 { + return nil, fmt.Errorf("encryption failed: %w", PopError()) } return outbuf[:outlen], nil } @@ -351,10 +354,10 @@ func (ctx *encryptionCipherCtx) EncryptFinal() ([]byte, error) { func (ctx *decryptionCipherCtx) DecryptFinal() ([]byte, error) { outbuf := make([]byte, ctx.BlockSize()) var outlen C.int - if 1 != C.EVP_DecryptFinal_ex(ctx.ctx, (*C.uchar)(&outbuf[0]), &outlen) { + if C.EVP_DecryptFinal_ex(ctx.ctx, (*C.uchar)(&outbuf[0]), &outlen) != 1 { // this may mean the tag failed to verify- all previous plaintext // returned must be considered faked and invalid - return nil, errors.New("decryption failed") + return nil, fmt.Errorf("decryption failed: %w", PopError()) } return outbuf[:outlen], nil } diff --git a/crypto/ciphers_gcm.go b/crypto/ciphers_gcm.go index afcbde1..24e8ce4 100644 --- a/crypto/ciphers_gcm.go +++ b/crypto/ciphers_gcm.go @@ -18,7 +18,6 @@ package crypto import "C" import ( - "errors" "fmt" ) @@ -29,7 +28,7 @@ type AuthenticatedEncryptionCipherCtx interface { // not encrypted itself, but is part of the authenticated data. when // decrypting or authenticating, pass back with the decryption // context's ExtraData() - ExtraData([]byte) error + ExtraData(extra []byte) error // use after finalizing encryption to get the authenticating tag GetTag() ([]byte, error) @@ -40,11 +39,11 @@ type AuthenticatedDecryptionCipherCtx interface { // pass in any extra data that was added during encryption with the // encryption context's ExtraData() - ExtraData([]byte) error + ExtraData(extra []byte) error // use before finalizing decryption to tell the library what the // tag is expected to be - SetTag([]byte) error + SetTag(tag []byte) error } type authEncryptionCipherCtx struct { @@ -65,13 +64,14 @@ func getGCMCipher(blocksize int) (*Cipher, error) { case 128: cipherptr = C.EVP_aes_128_gcm() default: - return nil, fmt.Errorf("unknown block size %d", blocksize) + return nil, ErrUknownBlockSize } return &Cipher{ptr: cipherptr}, nil } func NewGCMEncryptionCipherCtx(blocksize int, e *Engine, key, iv []byte) ( - AuthenticatedEncryptionCipherCtx, error) { + AuthenticatedEncryptionCipherCtx, error, +) { cipher, err := getGCMCipher(blocksize) if err != nil { return nil, err @@ -83,19 +83,19 @@ func NewGCMEncryptionCipherCtx(blocksize int, e *Engine, key, iv []byte) ( if len(iv) > 0 { err := ctx.SetCtrl(C.EVP_CTRL_GCM_SET_IVLEN, len(iv)) if err != nil { - return nil, fmt.Errorf("could not set IV len to %d: %s", + return nil, fmt.Errorf("could not set IV len to %d: %w", len(iv), err) } - if 1 != C.EVP_EncryptInit_ex(ctx.ctx, nil, nil, nil, - (*C.uchar)(&iv[0])) { - return nil, errors.New("failed to apply IV") + if C.EVP_EncryptInit_ex(ctx.ctx, nil, nil, nil, (*C.uchar)(&iv[0])) != 1 { + return nil, fmt.Errorf("failed to apply IV: %w", PopError()) } } return &authEncryptionCipherCtx{encryptionCipherCtx: ctx}, nil } func NewGCMDecryptionCipherCtx(blocksize int, e *Engine, key, iv []byte) ( - AuthenticatedDecryptionCipherCtx, error) { + AuthenticatedDecryptionCipherCtx, error, +) { cipher, err := getGCMCipher(blocksize) if err != nil { return nil, err @@ -107,12 +107,11 @@ func NewGCMDecryptionCipherCtx(blocksize int, e *Engine, key, iv []byte) ( if len(iv) > 0 { err := ctx.SetCtrl(C.EVP_CTRL_GCM_SET_IVLEN, len(iv)) if err != nil { - return nil, fmt.Errorf("could not set IV len to %d: %s", + return nil, fmt.Errorf("could not set IV len to %d: %w", len(iv), err) } - if 1 != C.EVP_DecryptInit_ex(ctx.ctx, nil, nil, nil, - (*C.uchar)(&iv[0])) { - return nil, errors.New("failed to apply IV") + if C.EVP_DecryptInit_ex(ctx.ctx, nil, nil, nil, (*C.uchar)(&iv[0])) != 1 { + return nil, fmt.Errorf("failed to apply IV: %w", PopError()) } } return &authDecryptionCipherCtx{decryptionCipherCtx: ctx}, nil @@ -123,9 +122,8 @@ func (ctx *authEncryptionCipherCtx) ExtraData(aad []byte) error { return nil } var outlen C.int - if 1 != C.EVP_EncryptUpdate(ctx.ctx, nil, &outlen, (*C.uchar)(&aad[0]), - C.int(len(aad))) { - return errors.New("failed to add additional authenticated data") + if C.EVP_EncryptUpdate(ctx.ctx, nil, &outlen, (*C.uchar)(&aad[0]), C.int(len(aad))) != 1 { + return fmt.Errorf("failed to add additional authenticated data: %w", PopError()) } return nil } @@ -135,16 +133,15 @@ func (ctx *authDecryptionCipherCtx) ExtraData(aad []byte) error { return nil } var outlen C.int - if 1 != C.EVP_DecryptUpdate(ctx.ctx, nil, &outlen, (*C.uchar)(&aad[0]), - C.int(len(aad))) { - return errors.New("failed to add additional authenticated data") + if C.EVP_DecryptUpdate(ctx.ctx, nil, &outlen, (*C.uchar)(&aad[0]), C.int(len(aad))) != 1 { + return fmt.Errorf("failed to add additional authenticated data: %w", PopError()) } return nil } func (ctx *authEncryptionCipherCtx) GetTag() ([]byte, error) { - return ctx.GetCtrlBytes(C.EVP_CTRL_GCM_GET_TAG, GCM_TAG_MAXLEN, - GCM_TAG_MAXLEN) + return ctx.GetCtrlBytes(C.EVP_CTRL_GCM_GET_TAG, GCMTagMaxLen, + GCMTagMaxLen) } func (ctx *authDecryptionCipherCtx) SetTag(tag []byte) error { diff --git a/crypto/ciphers_test.go b/crypto/ciphers_test.go index f2184f2..43a340b 100644 --- a/crypto/ciphers_test.go +++ b/crypto/ciphers_test.go @@ -12,124 +12,149 @@ // See the License for the specific language governing permissions and // limitations under the License. -package crypto +package crypto_test import ( "bytes" "fmt" "strings" "testing" + + "github.com/tongsuo-project/tongsuo-go-sdk/crypto" ) func expectError(t *testing.T, err error, msg string) { + t.Helper() + if err == nil { t.Fatalf("Expected error containing %#v, but got none", msg) } + if !strings.Contains(err.Error(), msg) { t.Fatalf("Expected error containing %#v, but got %s", msg, err) } } func TestBadInputs(t *testing.T) { - _, err := NewGCMEncryptionCipherCtx(256, nil, + t.Parallel() + + _, err := crypto.NewGCMEncryptionCipherCtx(256, nil, []byte("abcdefghijklmnopqrstuvwxyz"), nil) expectError(t, err, "bad key size") - _, err = NewGCMEncryptionCipherCtx(128, nil, + _, err = crypto.NewGCMEncryptionCipherCtx(128, nil, []byte("abcdefghijklmnopqrstuvwxyz"), nil) expectError(t, err, "bad key size") - _, err = NewGCMEncryptionCipherCtx(200, nil, + _, err = crypto.NewGCMEncryptionCipherCtx(200, nil, []byte("abcdefghijklmnopqrstuvwxy"), nil) expectError(t, err, "unknown block size") - c, err := GetCipherByName("AES-128-CBC") + + c, err := crypto.GetCipherByName("AES-128-CBC") if err != nil { t.Fatal("Could not look up AES-128-CBC") } - _, err = NewEncryptionCipherCtx(c, nil, []byte("abcdefghijklmnop"), + + _, err = crypto.NewEncryptionCipherCtx(c, nil, []byte("abcdefghijklmnop"), []byte("abc")) expectError(t, err, "bad IV size") } -func doEncryption(key, iv, aad, plaintext []byte, blocksize, bufsize int) ( - ciphertext, tag []byte, err error) { - ectx, err := NewGCMEncryptionCipherCtx(blocksize, nil, key, iv) +func doEncryption(key, iv, aad, plaintext []byte, blocksize, bufsize int) ([]byte, []byte, error) { + ectx, err := crypto.NewGCMEncryptionCipherCtx(blocksize, nil, key, iv) if err != nil { - return nil, nil, fmt.Errorf("Failed making GCM encryption ctx: %s", err) + return nil, nil, fmt.Errorf("Failed making GCM encryption ctx: %w", err) } + err = ectx.ExtraData(aad) if err != nil { - return nil, nil, fmt.Errorf("Failed to add authenticated data: %s", - err) + return nil, nil, fmt.Errorf("Failed to add authenticated data: %w", err) } + plainb := bytes.NewBuffer(plaintext) cipherb := new(bytes.Buffer) + for plainb.Len() > 0 { moar, err := ectx.EncryptUpdate(plainb.Next(bufsize)) if err != nil { - return nil, nil, fmt.Errorf("Failed to perform an encryption: %s", - err) + return nil, nil, fmt.Errorf("Failed to perform an encryption: %w", err) } + cipherb.Write(moar) } + moar, err := ectx.EncryptFinal() if err != nil { - return nil, nil, fmt.Errorf("Failed to finalize encryption: %s", err) + return nil, nil, fmt.Errorf("Failed to finalize encryption: %w", err) } + cipherb.Write(moar) - tag, err = ectx.GetTag() + + tag, err := ectx.GetTag() if err != nil { - return nil, nil, fmt.Errorf("Failed to get GCM tag: %s", err) + return nil, nil, fmt.Errorf("Failed to get GCM tag: %w", err) } + return cipherb.Bytes(), tag, nil } -func doDecryption(key, iv, aad, ciphertext, tag []byte, blocksize, - bufsize int) (plaintext []byte, err error) { - dctx, err := NewGCMDecryptionCipherCtx(blocksize, nil, key, iv) +func doDecryption(key, iv, aad, ciphertext, tag []byte, blocksize, bufsize int) ([]byte, error) { + dctx, err := crypto.NewGCMDecryptionCipherCtx(blocksize, nil, key, iv) if err != nil { - return nil, fmt.Errorf("Failed making GCM decryption ctx: %s", err) + return nil, fmt.Errorf("Failed making GCM decryption ctx: %w", err) } + aadbuf := bytes.NewBuffer(aad) for aadbuf.Len() > 0 { err = dctx.ExtraData(aadbuf.Next(bufsize)) if err != nil { - return nil, fmt.Errorf("Failed to add authenticated data: %s", err) + return nil, fmt.Errorf("Failed to add authenticated data: %w", err) } } + plainb := new(bytes.Buffer) + cipherb := bytes.NewBuffer(ciphertext) for cipherb.Len() > 0 { moar, err := dctx.DecryptUpdate(cipherb.Next(bufsize)) if err != nil { - return nil, fmt.Errorf("Failed to perform a decryption: %s", err) + return nil, fmt.Errorf("Failed to perform a decryption: %w", err) } + plainb.Write(moar) } + err = dctx.SetTag(tag) if err != nil { - return nil, fmt.Errorf("Failed to set expected GCM tag: %s", err) + return nil, fmt.Errorf("Failed to set expected GCM tag: %w", err) } + moar, err := dctx.DecryptFinal() if err != nil { - return nil, fmt.Errorf("Failed to finalize decryption: %s", err) + return nil, fmt.Errorf("Failed to finalize decryption: %w", err) } + plainb.Write(moar) + return plainb.Bytes(), nil } func checkEqual(t *testing.T, output []byte, original string) { - output_s := string(output) - if output_s != original { - t.Fatalf("output != original! %#v != %#v", output_s, original) + t.Helper() + + outputStr := string(output) + if outputStr != original { + t.Fatalf("output != original! %#v != %#v", outputStr, original) } } func TestGCM(t *testing.T) { + t.Parallel() + aad := []byte("foo bar baz") key := []byte("nobody can guess this i'm sure..") // len=32 iv := []byte("just a bunch of bytes") plaintext := "Long long ago, in a land far away..." - blocksizes_to_test := []int{256, 192, 128} + blocksizesToTest := []int{256, 192, 128} // best for this to have no common factors with blocksize, so that the // buffering layer inside the CIPHER_CTX gets exercised @@ -139,40 +164,49 @@ func TestGCM(t *testing.T) { plaintext += "!" // make sure padding is exercised } - for _, bsize := range blocksizes_to_test { + for _, bsize := range blocksizesToTest { subkey := key[:bsize/8] + ciphertext, tag, err := doEncryption(subkey, iv, aad, []byte(plaintext), bsize, bufsize) if err != nil { t.Fatalf("Encryption with b=%d: %s", bsize, err) } - plaintext_out, err := doDecryption(subkey, iv, aad, ciphertext, tag, + + plaintextOut, err := doDecryption(subkey, iv, aad, ciphertext, tag, bsize, bufsize) if err != nil { t.Fatalf("Decryption with b=%d: %s", bsize, err) } - checkEqual(t, plaintext_out, plaintext) + + checkEqual(t, plaintextOut, plaintext) } } func TestGCMWithNoAAD(t *testing.T) { + t.Parallel() + key := []byte("0000111122223333") iv := []byte("9999") - plaintext := "ABORT ABORT ABORT DANGAR" + plaintext := "ABORT DANGAR" ciphertext, tag, err := doEncryption(key, iv, nil, []byte(plaintext), 128, 32) if err != nil { t.Fatal("Encryption failure:", err) } - plaintext_out, err := doDecryption(key, iv, nil, ciphertext, tag, 128, 129) + + plaintextOut, err := doDecryption(key, iv, nil, ciphertext, tag, 128, 129) if err != nil { t.Fatal("Decryption failure:", err) } - checkEqual(t, plaintext_out, plaintext) + + checkEqual(t, plaintextOut, plaintext) } func TestBadTag(t *testing.T) { + t.Parallel() + key := []byte("abcdefghijklmnop") iv := []byte("v7239qjfv3qr793fuaj") plaintext := "The red rooster has flown the coop I REPEAT" + @@ -185,20 +219,25 @@ func TestBadTag(t *testing.T) { } // flip the last bit tag[len(tag)-1] ^= 1 - plaintext_out, err := doDecryption(key, iv, nil, ciphertext, tag, 128, 129) + + _, err = doDecryption(key, iv, nil, ciphertext, tag, 128, 129) if err == nil { t.Fatal("Expected error for bad tag, but got none") } // flip it back, try again just to make sure tag[len(tag)-1] ^= 1 - plaintext_out, err = doDecryption(key, iv, nil, ciphertext, tag, 128, 129) + + plaintextOut, err := doDecryption(key, iv, nil, ciphertext, tag, 128, 129) if err != nil { t.Fatal("Decryption failure:", err) } - checkEqual(t, plaintext_out, plaintext) + + checkEqual(t, plaintextOut, plaintext) } func TestBadCiphertext(t *testing.T) { + t.Parallel() + key := []byte("hard boiled eggs & bacon") iv := []byte("x") // it's not a very /good/ IV, is it aad := []byte("mu") @@ -211,20 +250,25 @@ func TestBadCiphertext(t *testing.T) { } // flip the last bit ciphertext[len(ciphertext)-1] ^= 1 - plaintext_out, err := doDecryption(key, iv, aad, ciphertext, tag, 192, 192) + + _, err = doDecryption(key, iv, aad, ciphertext, tag, 192, 192) if err == nil { t.Fatal("Expected error for bad ciphertext, but got none") } // flip it back, try again just to make sure ciphertext[len(ciphertext)-1] ^= 1 - plaintext_out, err = doDecryption(key, iv, aad, ciphertext, tag, 192, 192) + + plaintextOut, err := doDecryption(key, iv, aad, ciphertext, tag, 192, 192) if err != nil { t.Fatal("Decryption failure:", err) } - checkEqual(t, plaintext_out, plaintext) + + checkEqual(t, plaintextOut, plaintext) } func TestBadAAD(t *testing.T) { + t.Parallel() + key := []byte("Ive got a lovely buncha coconuts") iv := []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab") aad := []byte("Hi i am a plain") @@ -237,68 +281,85 @@ func TestBadAAD(t *testing.T) { } // flip the last bit aad[len(aad)-1] ^= 1 - plaintext_out, err := doDecryption(key, iv, aad, ciphertext, tag, 256, 256) + + _, err = doDecryption(key, iv, aad, ciphertext, tag, 256, 256) if err == nil { t.Fatal("Expected error for bad AAD, but got none") } // flip it back, try again just to make sure aad[len(aad)-1] ^= 1 - plaintext_out, err = doDecryption(key, iv, aad, ciphertext, tag, 256, 256) + + plaintextOut, err := doDecryption(key, iv, aad, ciphertext, tag, 256, 256) if err != nil { t.Fatal("Decryption failure:", err) } - checkEqual(t, plaintext_out, plaintext) + + checkEqual(t, plaintextOut, plaintext) } func TestNonAuthenticatedEncryption(t *testing.T) { + t.Parallel() + key := []byte("never gonna give you up, never g") iv := []byte("onna let you dow") plaintext1 := "n, never gonna run around" plaintext2 := " and desert you" - cipher, err := GetCipherByName("aes-256-cbc") + cipher, err := crypto.GetCipherByName("aes-256-cbc") if err != nil { t.Fatal("Could not get cipher: ", err) } - eCtx, err := NewEncryptionCipherCtx(cipher, nil, key, iv) + eCtx, err := crypto.NewEncryptionCipherCtx(cipher, nil, key, iv) if err != nil { t.Fatal("Could not create encryption context: ", err) } + cipherbytes, err := eCtx.EncryptUpdate([]byte(plaintext1)) if err != nil { t.Fatal("EncryptUpdate(plaintext1) failure: ", err) } + ciphertext := string(cipherbytes) + cipherbytes, err = eCtx.EncryptUpdate([]byte(plaintext2)) if err != nil { t.Fatal("EncryptUpdate(plaintext2) failure: ", err) } + ciphertext += string(cipherbytes) + cipherbytes, err = eCtx.EncryptFinal() if err != nil { t.Fatal("EncryptFinal() failure: ", err) } + ciphertext += string(cipherbytes) - dCtx, err := NewDecryptionCipherCtx(cipher, nil, key, iv) + dCtx, err := crypto.NewDecryptionCipherCtx(cipher, nil, key, iv) if err != nil { t.Fatal("Could not create decryption context: ", err) } + plainbytes, err := dCtx.DecryptUpdate([]byte(ciphertext[:15])) if err != nil { t.Fatal("DecryptUpdate(ciphertext part 1) failure: ", err) } + plainOutput := string(plainbytes) + plainbytes, err = dCtx.DecryptUpdate([]byte(ciphertext[15:])) if err != nil { t.Fatal("DecryptUpdate(ciphertext part 2) failure: ", err) } + plainOutput += string(plainbytes) + plainbytes, err = dCtx.DecryptFinal() if err != nil { t.Fatal("DecryptFinal() failure: ", err) } + plainOutput += string(plainbytes) checkEqual(t, []byte(plainOutput), plaintext1+plaintext2) diff --git a/crypto/dh.go b/crypto/dh.go index a2fc39a..41ac47d 100644 --- a/crypto/dh.go +++ b/crypto/dh.go @@ -16,10 +16,6 @@ package crypto // #include "shim.h" import "C" -import ( - "errors" - "unsafe" -) // DeriveSharedSecret derives a shared secret using a private key and a peer's // public key. @@ -29,38 +25,38 @@ func DeriveSharedSecret(private PrivateKey, public PublicKey) ([]byte, error) { // Create context for the shared secret derivation dhCtx := C.EVP_PKEY_CTX_new(private.EvpPKey(), nil) if dhCtx == nil { - return nil, errors.New("failed creating shared secret derivation context") + return nil, PopError() } defer C.EVP_PKEY_CTX_free(dhCtx) // Initialize the context if int(C.EVP_PKEY_derive_init(dhCtx)) != 1 { - return nil, errors.New("failed initializing shared secret derivation context") + return nil, PopError() } // Provide the peer's public key if int(C.EVP_PKEY_derive_set_peer(dhCtx, public.EvpPKey())) != 1 { - return nil, errors.New("failed adding peer public key to context") + return nil, PopError() } // Determine how large of a buffer we need for the shared secret var buffLen C.size_t if int(C.EVP_PKEY_derive(dhCtx, nil, &buffLen)) != 1 { - return nil, errors.New("failed determining shared secret length") + return nil, PopError() } // Allocate a buffer buffer := C.X_OPENSSL_malloc(buffLen) if buffer == nil { - return nil, errors.New("failed allocating buffer for shared secret") + return nil, ErrMallocFailure } defer C.X_OPENSSL_free(buffer) // Derive the shared secret if int(C.EVP_PKEY_derive(dhCtx, (*C.uchar)(buffer), &buffLen)) != 1 { - return nil, errors.New("failed deriving the shared secret") + return nil, PopError() } - secret := C.GoBytes(unsafe.Pointer(buffer), C.int(buffLen)) + secret := C.GoBytes(buffer, C.int(buffLen)) return secret, nil } diff --git a/crypto/dh_test.go b/crypto/dh_test.go index db71266..f069e46 100644 --- a/crypto/dh_test.go +++ b/crypto/dh_test.go @@ -12,35 +12,39 @@ // See the License for the specific language governing permissions and // limitations under the License. -package crypto +package crypto_test import ( "bytes" "testing" + + "github.com/tongsuo-project/tongsuo-go-sdk/crypto" ) func TestECDH(t *testing.T) { t.Parallel() - myKey, err := GenerateECKey(Prime256v1) + myKey, err := crypto.GenerateECKey(crypto.Prime256v1) if err != nil { t.Fatal(err) } - peerKey, err := GenerateECKey(Prime256v1) + + peerKey, err := crypto.GenerateECKey(crypto.Prime256v1) if err != nil { t.Fatal(err) } - mySecret, err := DeriveSharedSecret(myKey, peerKey) + mySecret, err := crypto.DeriveSharedSecret(myKey, peerKey) if err != nil { t.Fatal(err) } - theirSecret, err := DeriveSharedSecret(peerKey, myKey) + + theirSecret, err := crypto.DeriveSharedSecret(peerKey, myKey) if err != nil { t.Fatal(err) } - if bytes.Compare(mySecret, theirSecret) != 0 { + if !bytes.Equal(mySecret, theirSecret) { t.Fatal("shared secrets are different") } } diff --git a/crypto/dhparam.go b/crypto/dhparam.go index 3009c7f..9b88d3a 100644 --- a/crypto/dhparam.go +++ b/crypto/dhparam.go @@ -18,7 +18,6 @@ package crypto import "C" import ( - "errors" "runtime" "unsafe" ) @@ -33,20 +32,20 @@ func (dh *DH) GetDH() *C.DH { // LoadDHParametersFromPEM loads the Diffie-Hellman parameters from // a PEM-encoded block. -func LoadDHParametersFromPEM(pem_block []byte) (*DH, error) { - if len(pem_block) == 0 { - return nil, errors.New("empty pem block") +func LoadDHParametersFromPEM(pemBlock []byte) (*DH, error) { + if len(pemBlock) == 0 { + return nil, ErrNoCert } - bio := C.BIO_new_mem_buf(unsafe.Pointer(&pem_block[0]), - C.int(len(pem_block))) + bio := C.BIO_new_mem_buf(unsafe.Pointer(&pemBlock[0]), + C.int(len(pemBlock))) if bio == nil { - return nil, errors.New("failed creating bio") + return nil, ErrMallocFailure } defer C.BIO_free(bio) params := C.PEM_read_bio_DHparams(bio, nil, nil, nil) if params == nil { - return nil, errors.New("failed reading dh parameters") + return nil, PopError() } dhparams := &DH{dh: params} runtime.SetFinalizer(dhparams, func(dhparams *DH) { diff --git a/crypto/digest.go b/crypto/digest.go index a40e2d0..dd25abc 100644 --- a/crypto/digest.go +++ b/crypto/digest.go @@ -18,7 +18,6 @@ package crypto import "C" import ( - "fmt" "unsafe" ) @@ -38,7 +37,7 @@ func GetDigestByName(name string) (*Digest, error) { defer C.free(unsafe.Pointer(cname)) p := C.X_EVP_get_digestbyname(cname) if p == nil { - return nil, fmt.Errorf("Digest %v not found", name) + return nil, ErrUnsupportedDigest } // we can consider digests to use static mem; don't need to free return &Digest{ptr: p}, nil diff --git a/crypto/engine.go b/crypto/engine.go index 5a1a9a3..3551d41 100644 --- a/crypto/engine.go +++ b/crypto/engine.go @@ -36,18 +36,18 @@ func (e *Engine) Engine() *C.ENGINE { return e.e } -func EngineById(name string) (*Engine, error) { +func EngineByID(name string) (*Engine, error) { cname := C.CString(name) defer C.free(unsafe.Pointer(cname)) e := &Engine{ e: C.ENGINE_by_id(cname), } if e.e == nil { - return nil, fmt.Errorf("engine %s missing", name) + return nil, ErrNoEngine } if C.ENGINE_init(e.e) == 0 { C.ENGINE_free(e.e) - return nil, fmt.Errorf("engine %s not initialized", name) + return nil, fmt.Errorf("failed to init engine: %w", PopError()) } runtime.SetFinalizer(e, func(e *Engine) { C.ENGINE_finish(e.e) diff --git a/crypto/hmac.go b/crypto/hmac.go index de1e281..3134b82 100644 --- a/crypto/hmac.go +++ b/crypto/hmac.go @@ -18,7 +18,7 @@ package crypto import "C" import ( - "errors" + "fmt" "runtime" "unsafe" ) @@ -29,63 +29,58 @@ type HMAC struct { md *C.EVP_MD } -func NewHMAC(key []byte, digestAlgorithm EVP_MD) (*HMAC, error) { +func NewHMAC(key []byte, digestAlgorithm MDAlgo) (*HMAC, error) { return NewHMACWithEngine(key, digestAlgorithm, nil) } -func NewHMACWithEngine(key []byte, digestAlgorithm EVP_MD, e *Engine) (*HMAC, error) { +func NewHMACWithEngine(key []byte, digestAlgorithm MDAlgo, e *Engine) (*HMAC, error) { var md *C.EVP_MD = getDigestFunction(digestAlgorithm) - h := &HMAC{engine: e, md: md} - h.ctx = C.X_HMAC_CTX_new() - if h.ctx == nil { - return nil, errors.New("unable to allocate HMAC_CTX") + hmac := &HMAC{ctx: nil, engine: e, md: md} + hmac.ctx = C.X_HMAC_CTX_new() + if hmac.ctx == nil { + return nil, ErrMallocFailure } - var c_e *C.ENGINE + var cEngine *C.ENGINE if e != nil { - c_e = (*C.ENGINE)(e.Engine()) + cEngine = e.Engine() } - if rc := C.X_HMAC_Init_ex(h.ctx, - unsafe.Pointer(&key[0]), - C.int(len(key)), - md, - c_e); rc != 1 { - C.X_HMAC_CTX_free(h.ctx) - return nil, errors.New("failed to initialize HMAC_CTX") + if rc := C.X_HMAC_Init_ex(hmac.ctx, unsafe.Pointer(&key[0]), C.int(len(key)), md, cEngine); rc != 1 { + C.X_HMAC_CTX_free(hmac.ctx) + return nil, fmt.Errorf("failed to init HMAC_CTX: %w", PopError()) } - runtime.SetFinalizer(h, func(h *HMAC) { h.Close() }) - return h, nil + runtime.SetFinalizer(hmac, func(h *HMAC) { h.Close() }) + return hmac, nil } func (h *HMAC) Close() { C.X_HMAC_CTX_free(h.ctx) } -func (h *HMAC) Write(data []byte) (n int, err error) { +func (h *HMAC) Write(data []byte) (int, error) { if len(data) == 0 { return 0, nil } - if rc := C.X_HMAC_Update(h.ctx, (*C.uchar)(unsafe.Pointer(&data[0])), - C.size_t(len(data))); rc != 1 { - return 0, errors.New("failed to update HMAC") + if C.X_HMAC_Update(h.ctx, (*C.uchar)(unsafe.Pointer(&data[0])), C.size_t(len(data))) != 1 { + return 0, fmt.Errorf("failed to update HMAC: %w", PopError()) } return len(data), nil } func (h *HMAC) Reset() error { - if 1 != C.X_HMAC_Init_ex(h.ctx, nil, 0, nil, nil) { - return errors.New("failed to reset HMAC_CTX") + if C.X_HMAC_Init_ex(h.ctx, nil, 0, nil, nil) != 1 { + return fmt.Errorf("failed to reset HMAC_CTX: %w", PopError()) } return nil } -func (h *HMAC) Final() (result []byte, err error) { +func (h *HMAC) Final() ([]byte, error) { mdLength := C.X_EVP_MD_size(h.md) - result = make([]byte, mdLength) + result := make([]byte, mdLength) if rc := C.X_HMAC_Final(h.ctx, (*C.uchar)(unsafe.Pointer(&result[0])), (*C.uint)(unsafe.Pointer(&mdLength))); rc != 1 { - return nil, errors.New("failed to finalized HMAC") + return nil, fmt.Errorf("failed to final HMAC: %w", PopError()) } return result, h.Reset() } diff --git a/crypto/hmac_test.go b/crypto/hmac_test.go index a429866..51f02ae 100644 --- a/crypto/hmac_test.go +++ b/crypto/hmac_test.go @@ -12,31 +12,38 @@ // See the License for the specific language governing permissions and // limitations under the License. -package crypto +package crypto_test import ( "crypto/hmac" "crypto/sha256" "encoding/hex" "testing" + + "github.com/tongsuo-project/tongsuo-go-sdk/crypto" ) func TestSHA256HMAC(t *testing.T) { + t.Parallel() + key := []byte("d741787cc61851af045ccd37") data := []byte("5912EEFD-59EC-43E3-ADB8-D5325AEC3271") - h, err := NewHMAC(key, EVP_SHA256) + tsHmac, err := crypto.NewHMAC(key, crypto.MDSHA256) if err != nil { t.Fatalf("Unable to create new HMAC: %s", err) } - if _, err := h.Write(data); err != nil { + + if _, err := tsHmac.Write(data); err != nil { t.Fatalf("Unable to write data into HMAC: %s", err) } var actualHMACBytes []byte - if actualHMACBytes, err = h.Final(); err != nil { + + if actualHMACBytes, err = tsHmac.Final(); err != nil { t.Fatalf("Error while finalizing HMAC: %s", err) } + actualString := hex.EncodeToString(actualHMACBytes) // generate HMAC with built-in crypto lib @@ -53,19 +60,20 @@ func BenchmarkSHA256HMAC(b *testing.B) { key := []byte("d741787cc61851af045ccd37") data := []byte("5912EEFD-59EC-43E3-ADB8-D5325AEC3271") - h, err := NewHMAC(key, EVP_SHA256) + tsHmac, err := crypto.NewHMAC(key, crypto.MDSHA256) if err != nil { b.Fatalf("Unable to create new HMAC: %s", err) } b.ResetTimer() + for i := 0; i < b.N; i++ { - if _, err := h.Write(data); err != nil { + if _, err := tsHmac.Write(data); err != nil { b.Fatalf("Unable to write data into HMAC: %s", err) } var err error - if _, err = h.Final(); err != nil { + if _, err = tsHmac.Final(); err != nil { b.Fatalf("Error while finalizing HMAC: %s", err) } } diff --git a/crypto/hostname.go b/crypto/hostname.go index 02f0636..1980dcd 100644 --- a/crypto/hostname.go +++ b/crypto/hostname.go @@ -34,15 +34,10 @@ extern int X509_check_ip(X509 *x, const unsigned char *chk, size_t chklen, import "C" import ( - "errors" "net" "unsafe" ) -var ( - ValidationError = errors.New("Host validation error") -) - type CheckFlags int const ( @@ -65,9 +60,13 @@ func (c *Certificate) CheckHost(host string, flags CheckFlags) error { return nil } if rv == 0 { - return ValidationError + return ErrMatchFailed } - return errors.New("hostname validation had an internal failure") + if rv == -2 { + return ErrInputInvalid + } + + return ErrInternalError } // CheckEmail checks that the X509 certificate is signed for the provided @@ -84,9 +83,13 @@ func (c *Certificate) CheckEmail(email string, flags CheckFlags) error { return nil } if rv == 0 { - return ValidationError + return ErrMatchFailed } - return errors.New("email validation had an internal failure") + if rv == -2 { + return ErrInputInvalid + } + + return ErrInternalError } // CheckIP checks that the X509 certificate is signed for the provided @@ -108,9 +111,13 @@ func (c *Certificate) CheckIP(ip net.IP, flags CheckFlags) error { return nil } if rv == 0 { - return ValidationError + return ErrMatchFailed } - return errors.New("ip validation had an internal failure") + if rv == -2 { + return ErrInputInvalid + } + + return ErrInternalError } // VerifyHostname is a combination of CheckHost and CheckIP. If the provided diff --git a/crypto/init.go b/crypto/init.go index c5ee479..5be290c 100644 --- a/crypto/init.go +++ b/crypto/init.go @@ -16,15 +16,49 @@ import ( "strings" ) +var ( + ErrMallocFailure = errors.New("malloc failure") + ErrNilParameter = errors.New("nil parameter") + ErrNoCipher = errors.New("no cipher") + ErrNoVersion = errors.New("no version") + ErrUnexpectedEOF = errors.New("unexpected EOF") + ErrNoPeerCert = errors.New("no peer certificate") + ErrShutdown = errors.New("shutdown") + ErrNoSession = errors.New("no session") + ErrSessionLength = errors.New("session length error") + ErrEmptySession = errors.New("empty session") + ErrNoALPN = errors.New("no ALPN negotiated") + ErrWrongKeyType = errors.New("wrong key type") + ErrUnknownTLSVersion = errors.New("unknown TLS version") + ErrNoCert = errors.New("no certificate") + ErrNoKey = errors.New("no key") + ErrUnsupportedMode = errors.New("unsupported cipher mode") + ErrPartialWrite = errors.New("partial write") + ErrUnsupportedDigest = errors.New("unsupported digest") + ErrInvalidNid = errors.New("invalid NID") + ErrEmptyExtensionValue = errors.New("empty extension value") + ErrNoPubKey = errors.New("no public key") + ErrCipherNotFound = errors.New("cipher not found") + ErrBadKeySize = errors.New("bad key size") + ErrBadIvSize = errors.New("bad IV size") + ErrUknownBlockSize = errors.New("unknown block size") + ErrNoEngine = errors.New("engine not found") + ErrMatchFailed = errors.New("match failed") + ErrInputInvalid = errors.New("input invalid") + ErrInternalError = errors.New("internal error") + ErrEmptyKey = errors.New("empty key") + ErrNoData = errors.New("no data") +) + func init() { if rc := C.X_tscrypto_init(); rc != 0 { - panic(fmt.Errorf("X_tscrypto_init failed with %d", rc)) + panic(fmt.Sprintf("X_tscrypto_init failed with %d", rc)) } } -// ErrorFromErrorQueue needs to run in the same OS thread as the operation +// PopError needs to run in the same OS thread as the operation // that caused the possible error -func ErrorFromErrorQueue() error { +func PopError() error { var errs []string for { err := C.ERR_get_error() @@ -36,5 +70,6 @@ func ErrorFromErrorQueue() error { C.GoString(C.ERR_func_error_string(err)), C.GoString(C.ERR_reason_error_string(err)))) } - return errors.New(fmt.Sprintf("SSL errors: %s", strings.Join(errs, "\n"))) + + return errors.New("error string: " + strings.Join(errs, "\n")) } diff --git a/crypto/init_windows.go b/crypto/init_windows.go index c3f8404..2f1520f 100644 --- a/crypto/init_windows.go +++ b/crypto/init_windows.go @@ -25,9 +25,7 @@ package crypto CRITICAL_SECTION* goopenssl_locks; int go_init_locks() { - int rc = 0; int nlock; - int i; int locks_needed = CRYPTO_num_locks(); goopenssl_locks = (CRITICAL_SECTION*)malloc( diff --git a/crypto/key.go b/crypto/key.go index de194ee..928f304 100644 --- a/crypto/key.go +++ b/crypto/key.go @@ -18,48 +18,53 @@ package crypto import "C" import ( - "errors" - "io/ioutil" + "fmt" + "io" "runtime" "unsafe" ) -var ( // some (effectively) constants for tests to refer to - ed25519_support = C.X_ED25519_SUPPORT != 0 -) - type Method *C.EVP_MD -var ( - SHA1_Method Method = C.X_EVP_sha1() - SHA256_Method Method = C.X_EVP_sha256() - SHA512_Method Method = C.X_EVP_sha512() - SM3_Method Method = C.X_EVP_sm3() -) +func SHA1Method() Method { + return C.X_EVP_sha1() +} + +func SHA256Method() Method { + return C.X_EVP_sha256() +} + +func SHA512Method() Method { + return C.X_EVP_sha512() +} + +func SM3Method() Method { + return C.X_EVP_sm3() +} // Constants for the various key types. // Mapping of name -> NID taken from openssl/evp.h const ( - KeyTypeNone = NID_undef - KeyTypeRSA = NID_rsaEncryption - KeyTypeRSA2 = NID_rsa - KeyTypeDSA = NID_dsa - KeyTypeDSA1 = NID_dsa_2 - KeyTypeDSA2 = NID_dsaWithSHA - KeyTypeDSA3 = NID_dsaWithSHA1 - KeyTypeDSA4 = NID_dsaWithSHA1_2 - KeyTypeDH = NID_dhKeyAgreement - KeyTypeDHX = NID_dhpublicnumber - KeyTypeEC = NID_X9_62_id_ecPublicKey - KeyTypeHMAC = NID_hmac - KeyTypeCMAC = NID_cmac - KeyTypeTLS1PRF = NID_tls1_prf - KeyTypeHKDF = NID_hkdf - KeyTypeX25519 = NID_X25519 - KeyTypeX448 = NID_X448 - KeyTypeED25519 = NID_ED25519 - KeyTypeED448 = NID_ED448 - KeyTypeSM2 = NID_sm2 + KeyTypeNone = NidUndef + KeyTypeRSA = NidRsaEncryption + KeyTypeRSA2 = NidRsa + KeyTypeDSA = NidDsa + KeyTypeDSA1 = NidDsa2 + KeyTypeDSA2 = NidDsaWithSHA + KeyTypeDSA3 = NidDsaWithSHA1 + KeyTypeDSA4 = NidDsaWithSHA12 + KeyTypeDH = NidDhKeyAgreement + KeyTypeDHX = NidDhpublicnumber + KeyTypeEC = NidX962IdEcPublicKey + KeyTypeHMAC = NidHmac + KeyTypeCMAC = NidCmac + KeyTypeTLS1PRF = NidTLS1Prf + KeyTypeHKDF = NidHkdf + KeyTypeX25519 = NidX25519 + KeyTypeX448 = NidX448 + KeyTypeED25519 = NidEd25519 + KeyTypeED448 = NidEd448 + KeyTypeSM2 = NidSM2 ) type PublicKey interface { @@ -71,11 +76,11 @@ type PublicKey interface { // MarshalPKIXPublicKeyPEM converts the public key to PEM-encoded PKIX // format - MarshalPKIXPublicKeyPEM() (pem_block []byte, err error) + MarshalPKIXPublicKeyPEM() (pemBlock []byte, err error) // MarshalPKIXPublicKeyDER converts the public key to DER-encoded PKIX // format - MarshalPKIXPublicKeyDER() (der_block []byte, err error) + MarshalPKIXPublicKeyDER() (derBlock []byte, err error) // KeyType returns an identifier for what kind of key is represented by this // object. @@ -100,22 +105,26 @@ type PrivateKey interface { Public() PublicKey // SignPKCS1v15 signs the data using PKCS1.15 - SignPKCS1v15(Method, []byte) ([]byte, error) + SignPKCS1v15(method Method, data []byte) ([]byte, error) // Decrypt decrypts the data using SM2 Decrypt(data []byte) ([]byte, error) // MarshalPKCS1PrivateKeyPEM converts the private key to PEM-encoded PKCS1 // format - MarshalPKCS1PrivateKeyPEM() (pem_block []byte, err error) + MarshalPKCS1PrivateKeyPEM() (pemBlock []byte, err error) // MarshalPKCS1PrivateKeyDER converts the private key to DER-encoded PKCS1 // format - MarshalPKCS1PrivateKeyDER() (der_block []byte, err error) + MarshalPKCS1PrivateKeyDER() (derBlock []byte, err error) // MarshalPKCS8PrivateKeyPEM converts the private key to PEM-encoded PKCS8 // format - MarshalPKCS8PrivateKeyPEM() (pem_block []byte, err error) + MarshalPKCS8PrivateKeyPEM() (pemBlock []byte, err error) +} + +func SupportEd25519() bool { + return C.X_ED25519_SUPPORT != 0 } type pKey struct { @@ -147,54 +156,48 @@ func (key *pKey) Public() PublicKey { } func (key *pKey) SignPKCS1v15(method Method, data []byte) ([]byte, error) { - ctx := C.X_EVP_MD_CTX_new() defer C.X_EVP_MD_CTX_free(ctx) if key.KeyType() == KeyTypeED25519 { // do ED specific one-shot sign if method != nil || len(data) == 0 { - return nil, errors.New("signpkcs1v15: 0-length data or non-null digest") + return nil, ErrNilParameter } - if 1 != C.X_EVP_DigestSignInit(ctx, nil, nil, nil, key.key) { - return nil, errors.New("signpkcs1v15: failed to init signature") + if C.X_EVP_DigestSignInit(ctx, nil, nil, nil, key.key) != 1 { + return nil, PopError() } - // evp signatures are 64 bytes - sig := make([]byte, 64, 64) - var sigblen C.size_t = 64 - if 1 != C.X_EVP_DigestSign(ctx, - (*C.uchar)(unsafe.Pointer(&sig[0])), - &sigblen, - (*C.uchar)(unsafe.Pointer(&data[0])), - C.size_t(len(data))) { - return nil, errors.New("signpkcs1v15: failed to do one-shot signature") - } + var sigblen C.size_t = C.size_t(C.X_EVP_PKEY_size(key.key)) + sig := make([]byte, sigblen) - return sig[:sigblen], nil - } else { - if 1 != C.X_EVP_DigestSignInit(ctx, nil, method, nil, key.key) { - return nil, errors.New("signpkcs1v15: failed to init signature") + if C.X_EVP_DigestSign(ctx, (*C.uchar)(unsafe.Pointer(&sig[0])), &sigblen, (*C.uchar)(unsafe.Pointer(&data[0])), + C.size_t(len(data))) != 1 { + return nil, PopError() } - if len(data) > 0 { - if 1 != C.X_EVP_DigestSignUpdate( - ctx, unsafe.Pointer(&data[0]), C.size_t(len(data))) { - return nil, errors.New("signpkcs1v15: failed to update signature") - } - } + return sig[:sigblen], nil + } - var sigblen C.size_t = C.size_t(C.X_EVP_PKEY_size(key.key)) - sig := make([]byte, sigblen) + if C.X_EVP_DigestSignInit(ctx, nil, method, nil, key.key) != 1 { + return nil, PopError() + } - if 1 != C.X_EVP_DigestSignFinal(ctx, - (*C.uchar)(unsafe.Pointer(&sig[0])), &sigblen) { - return nil, errors.New("signpkcs1v15: failed to finalize signature") + if len(data) > 0 { + if C.X_EVP_DigestSignUpdate(ctx, unsafe.Pointer(&data[0]), C.size_t(len(data))) != 1 { + return nil, PopError() } + } - return sig[:sigblen], nil + var sigblen C.size_t = C.size_t(C.X_EVP_PKEY_size(key.key)) + sig := make([]byte, sigblen) + + if C.X_EVP_DigestSignFinal(ctx, (*C.uchar)(unsafe.Pointer(&sig[0])), &sigblen) != 1 { + return nil, PopError() } + + return sig[:sigblen], nil } func (key *pKey) VerifyPKCS1v15(method Method, data, sig []byte) error { @@ -205,63 +208,57 @@ func (key *pKey) VerifyPKCS1v15(method Method, data, sig []byte) error { // do ED specific one-shot sign if method != nil || len(data) == 0 || len(sig) == 0 { - return errors.New("verifypkcs1v15: 0-length data or sig or non-null digest") + return ErrNilParameter } - if 1 != C.X_EVP_DigestVerifyInit(ctx, nil, nil, nil, key.key) { - return errors.New("verifypkcs1v15: failed to init verify") + if C.X_EVP_DigestVerifyInit(ctx, nil, nil, nil, key.key) != 1 { + return PopError() } - if 1 != C.X_EVP_DigestVerify(ctx, - ((*C.uchar)(unsafe.Pointer(&sig[0]))), - C.size_t(len(sig)), - (*C.uchar)(unsafe.Pointer(&data[0])), - C.size_t(len(data))) { - return errors.New("verifypkcs1v15: failed to do one-shot verify") + if C.X_EVP_DigestVerify(ctx, ((*C.uchar)(unsafe.Pointer(&sig[0]))), C.size_t(len(sig)), + (*C.uchar)(unsafe.Pointer(&data[0])), C.size_t(len(data))) != 1 { + return PopError() } return nil + } - } else { - if 1 != C.X_EVP_DigestVerifyInit(ctx, nil, method, nil, key.key) { - return errors.New("verifypkcs1v15: failed to init verify") - } - - if len(data) > 0 { - if 1 != C.X_EVP_DigestVerifyUpdate( - ctx, unsafe.Pointer(&data[0]), C.size_t(len(data))) { - return errors.New("verifypkcs1v15: failed to update verify") - } - } + if C.X_EVP_DigestVerifyInit(ctx, nil, method, nil, key.key) != 1 { + return PopError() + } - if 1 != C.X_EVP_DigestVerifyFinal(ctx, - (*C.uchar)(unsafe.Pointer(&sig[0])), C.size_t(len(sig))) { - return errors.New("verifypkcs1v15: failed to finalize verify") + if len(data) > 0 { + if C.X_EVP_DigestVerifyUpdate(ctx, unsafe.Pointer(&data[0]), C.size_t(len(data))) != 1 { + return PopError() } + } - return nil + if C.X_EVP_DigestVerifyFinal(ctx, (*C.uchar)(unsafe.Pointer(&sig[0])), C.size_t(len(sig))) != 1 { + return PopError() } + + return nil } func (key *pKey) MarshalPKCS8PrivateKeyPEM() ([]byte, error) { if key.key == nil { - return nil, errors.New("empty key") + return nil, ErrEmptyKey } bio := C.BIO_new(C.BIO_s_mem()) if bio == nil { - return nil, errors.New("failed to allocate memory") + return nil, ErrMallocFailure } defer C.BIO_free(bio) if C.PEM_write_bio_PKCS8PrivateKey(bio, key.key, nil, nil, 0, nil, nil) != 1 { - return nil, errors.New("failed to write private key") + return nil, PopError() } var ptr *C.char length := C.X_BIO_get_mem_data(bio, &ptr) if length <= 0 { - return nil, errors.New("failed to read bio data") + return nil, ErrNoData } result := C.GoBytes(unsafe.Pointer(ptr), C.int(length)) @@ -272,22 +269,20 @@ func (key *pKey) Encrypt(data []byte) ([]byte, error) { ctx := C.EVP_PKEY_CTX_new(key.key, nil) defer C.EVP_PKEY_CTX_free(ctx) - if 1 != C.EVP_PKEY_encrypt_init(ctx) { - return nil, errors.New("encrypt: failed to init encryption") + if C.EVP_PKEY_encrypt_init(ctx) != 1 { + return nil, PopError() } var enclen C.size_t - if 1 != C.EVP_PKEY_encrypt(ctx, nil, &enclen, - (*C.uchar)(unsafe.Pointer(&data[0])), C.size_t(len(data))) { - return nil, errors.New("encrypt: failed to determine encryption length") + if C.EVP_PKEY_encrypt(ctx, nil, &enclen, (*C.uchar)(unsafe.Pointer(&data[0])), C.size_t(len(data))) != 1 { + return nil, PopError() } enc := make([]byte, enclen) - if 1 != C.EVP_PKEY_encrypt(ctx, - (*C.uchar)(unsafe.Pointer(&enc[0])), &enclen, - (*C.uchar)(unsafe.Pointer(&data[0])), C.size_t(len(data))) { - return nil, errors.New("encrypt: failed to finish encryption") + if C.EVP_PKEY_encrypt(ctx, (*C.uchar)(unsafe.Pointer(&enc[0])), &enclen, (*C.uchar)(unsafe.Pointer(&data[0])), + C.size_t(len(data))) != 1 { + return nil, PopError() } return enc[:enclen], nil @@ -296,36 +291,33 @@ func (key *pKey) Encrypt(data []byte) ([]byte, error) { func (key *pKey) Decrypt(data []byte) ([]byte, error) { ctx := C.EVP_PKEY_CTX_new(key.key, nil) if ctx == nil { - return nil, errors.New("decrypt: failed to create context") + return nil, ErrMallocFailure } defer C.EVP_PKEY_CTX_free(ctx) - if 1 != C.EVP_PKEY_decrypt_init(ctx) { - return nil, errors.New("decrypt: failed to init decryption") + if C.EVP_PKEY_decrypt_init(ctx) != 1 { + return nil, PopError() } var declen C.size_t - if 1 != C.EVP_PKEY_decrypt(ctx, nil, &declen, - (*C.uchar)(unsafe.Pointer(&data[0])), C.size_t(len(data))) { - return nil, errors.New("decrypt: failed to determine decryption length") + if C.EVP_PKEY_decrypt(ctx, nil, &declen, (*C.uchar)(unsafe.Pointer(&data[0])), C.size_t(len(data))) != 1 { + return nil, PopError() } dec := make([]byte, declen) - if 1 != C.EVP_PKEY_decrypt(ctx, - (*C.uchar)(unsafe.Pointer(&dec[0])), &declen, - (*C.uchar)(unsafe.Pointer(&data[0])), C.size_t(len(data))) { - return nil, errors.New("decrypt: failed to finish decryption") + if C.EVP_PKEY_decrypt(ctx, (*C.uchar)(unsafe.Pointer(&dec[0])), &declen, (*C.uchar)(unsafe.Pointer(&data[0])), + C.size_t(len(data))) != 1 { + return nil, PopError() } return dec[:declen], nil } -func (key *pKey) MarshalPKCS1PrivateKeyPEM() (pem_block []byte, - err error) { +func (key *pKey) MarshalPKCS1PrivateKeyPEM() ([]byte, error) { bio := C.BIO_new(C.BIO_s_mem()) if bio == nil { - return nil, errors.New("failed to allocate memory BIO") + return nil, ErrMallocFailure } defer C.BIO_free(bio) @@ -334,105 +326,123 @@ func (key *pKey) MarshalPKCS1PrivateKeyPEM() (pem_block []byte, // to a PKCS8 key. if int(C.X_PEM_write_bio_PrivateKey_traditional(bio, key.key, nil, nil, C.int(0), nil, nil)) != 1 { - return nil, errors.New("failed dumping private key") + return nil, PopError() } - return ioutil.ReadAll(asAnyBio(bio)) + pem, err := io.ReadAll(asAnyBio(bio)) + if err != nil { + return nil, fmt.Errorf("failed to read bio data: %w", err) + } + + return pem, nil } -func (key *pKey) MarshalPKCS1PrivateKeyDER() (der_block []byte, - err error) { +func (key *pKey) MarshalPKCS1PrivateKeyDER() ([]byte, error) { bio := C.BIO_new(C.BIO_s_mem()) if bio == nil { - return nil, errors.New("failed to allocate memory BIO") + return nil, ErrMallocFailure } defer C.BIO_free(bio) if int(C.i2d_PrivateKey_bio(bio, key.key)) != 1 { - return nil, errors.New("failed dumping private key der") + return nil, PopError() + } + + ret, err := io.ReadAll(asAnyBio(bio)) + if err != nil { + return nil, fmt.Errorf("failed to read bio data: %w", err) } - return ioutil.ReadAll(asAnyBio(bio)) + return ret, nil } -func (key *pKey) MarshalPKIXPublicKeyPEM() (pem_block []byte, - err error) { +func (key *pKey) MarshalPKIXPublicKeyPEM() ([]byte, error) { bio := C.BIO_new(C.BIO_s_mem()) if bio == nil { - return nil, errors.New("failed to allocate memory BIO") + return nil, ErrMallocFailure } defer C.BIO_free(bio) if int(C.PEM_write_bio_PUBKEY(bio, key.key)) != 1 { - return nil, errors.New("failed dumping public key pem") + return nil, PopError() + } + + ret, err := io.ReadAll(asAnyBio(bio)) + if err != nil { + return nil, fmt.Errorf("failed to read bio data: %w", err) } - return ioutil.ReadAll(asAnyBio(bio)) + return ret, nil } -func (key *pKey) MarshalPKIXPublicKeyDER() (der_block []byte, - err error) { +func (key *pKey) MarshalPKIXPublicKeyDER() ([]byte, error) { bio := C.BIO_new(C.BIO_s_mem()) if bio == nil { - return nil, errors.New("failed to allocate memory BIO") + return nil, ErrMallocFailure } defer C.BIO_free(bio) if int(C.i2d_PUBKEY_bio(bio, key.key)) != 1 { - return nil, errors.New("failed dumping public key der") + return nil, PopError() } - return ioutil.ReadAll(asAnyBio(bio)) + ret, err := io.ReadAll(asAnyBio(bio)) + if err != nil { + return nil, fmt.Errorf("failed to read bio data: %w", err) + } + + return ret, nil } // LoadPrivateKeyFromPEM loads a private key from a PEM-encoded block. -func LoadPrivateKeyFromPEM(pem_block []byte) (PrivateKey, error) { - if len(pem_block) == 0 { - return nil, errors.New("empty pem block") +func LoadPrivateKeyFromPEM(pemBlock []byte) (PrivateKey, error) { + if len(pemBlock) == 0 { + return nil, ErrNoCert } - bio := C.BIO_new_mem_buf(unsafe.Pointer(&pem_block[0]), - C.int(len(pem_block))) + bio := C.BIO_new_mem_buf(unsafe.Pointer(&pemBlock[0]), + C.int(len(pemBlock))) if bio == nil { - return nil, errors.New("failed creating bio") + return nil, ErrMallocFailure } defer C.BIO_free(bio) key := C.PEM_read_bio_PrivateKey(bio, nil, nil, nil) if key == nil { - return nil, errors.New("failed reading private key") + return nil, PopError() } - p := &pKey{key: key} - runtime.SetFinalizer(p, func(p *pKey) { + priKey := &pKey{key: key} + runtime.SetFinalizer(priKey, func(p *pKey) { C.X_EVP_PKEY_free(p.key) }) - if C.X_EVP_PKEY_is_sm2(p.key) == 1 { - if C.EVP_PKEY_set_alias_type(p.key, C.EVP_PKEY_SM2) != 1 { - return nil, errors.New("failed set alias type") + if C.X_EVP_PKEY_is_sm2(priKey.key) == 1 { + if C.EVP_PKEY_set_alias_type(priKey.key, C.EVP_PKEY_SM2) != 1 { + return nil, PopError() } } - return p, nil + return priKey, nil } // LoadPrivateKeyFromPEMWithPassword loads a private key from a PEM-encoded block. -func LoadPrivateKeyFromPEMWithPassword(pem_block []byte, password string) ( - PrivateKey, error) { - if len(pem_block) == 0 { - return nil, errors.New("empty pem block") - } - bio := C.BIO_new_mem_buf(unsafe.Pointer(&pem_block[0]), - C.int(len(pem_block))) +func LoadPrivateKeyFromPEMWithPassword(pemBlock []byte, password string) ( + PrivateKey, error, +) { + if len(pemBlock) == 0 { + return nil, ErrNoKey + } + bio := C.BIO_new_mem_buf(unsafe.Pointer(&pemBlock[0]), + C.int(len(pemBlock))) if bio == nil { - return nil, errors.New("failed creating bio") + return nil, ErrMallocFailure } defer C.BIO_free(bio) cs := C.CString(password) defer C.free(unsafe.Pointer(cs)) key := C.PEM_read_bio_PrivateKey(bio, nil, nil, unsafe.Pointer(cs)) if key == nil { - return nil, errors.New("failed reading private key") + return nil, PopError() } p := &pKey{key: key} @@ -443,20 +453,20 @@ func LoadPrivateKeyFromPEMWithPassword(pem_block []byte, password string) ( } // LoadPrivateKeyFromDER loads a private key from a DER-encoded block. -func LoadPrivateKeyFromDER(der_block []byte) (PrivateKey, error) { - if len(der_block) == 0 { - return nil, errors.New("empty der block") +func LoadPrivateKeyFromDER(derBlock []byte) (PrivateKey, error) { + if len(derBlock) == 0 { + return nil, ErrNoKey } - bio := C.BIO_new_mem_buf(unsafe.Pointer(&der_block[0]), - C.int(len(der_block))) + bio := C.BIO_new_mem_buf(unsafe.Pointer(&derBlock[0]), + C.int(len(derBlock))) if bio == nil { - return nil, errors.New("failed creating bio") + return nil, ErrMallocFailure } defer C.BIO_free(bio) key := C.d2i_PrivateKey_bio(bio, nil) if key == nil { - return nil, errors.New("failed reading private key der") + return nil, PopError() } p := &pKey{key: key} @@ -468,50 +478,52 @@ func LoadPrivateKeyFromDER(der_block []byte) (PrivateKey, error) { // LoadPrivateKeyFromPEMWidthPassword loads a private key from a PEM-encoded block. // Backwards-compatible with typo -func LoadPrivateKeyFromPEMWidthPassword(pem_block []byte, password string) ( - PrivateKey, error) { - return LoadPrivateKeyFromPEMWithPassword(pem_block, password) +func LoadPrivateKeyFromPEMWidthPassword(pemBlock []byte, password string) ( + PrivateKey, error, +) { + return LoadPrivateKeyFromPEMWithPassword(pemBlock, password) } // LoadPublicKeyFromPEM loads a public key from a PEM-encoded block. -func LoadPublicKeyFromPEM(pem_block []byte) (PublicKey, error) { - if len(pem_block) == 0 { - return nil, errors.New("empty pem block") +func LoadPublicKeyFromPEM(pemBlock []byte) (PublicKey, error) { + if len(pemBlock) == 0 { + return nil, ErrNoPubKey } - bio := C.BIO_new_mem_buf(unsafe.Pointer(&pem_block[0]), - C.int(len(pem_block))) + + bio := C.BIO_new_mem_buf(unsafe.Pointer(&pemBlock[0]), C.int(len(pemBlock))) if bio == nil { - return nil, errors.New("failed creating bio") + return nil, ErrMallocFailure } defer C.BIO_free(bio) key := C.PEM_read_bio_PUBKEY(bio, nil, nil, nil) if key == nil { - return nil, errors.New("failed reading public key der") + return nil, PopError() } p := &pKey{key: key} runtime.SetFinalizer(p, func(p *pKey) { C.X_EVP_PKEY_free(p.key) }) + return p, nil } // LoadPublicKeyFromDER loads a public key from a DER-encoded block. -func LoadPublicKeyFromDER(der_block []byte) (PublicKey, error) { - if len(der_block) == 0 { - return nil, errors.New("empty der block") +func LoadPublicKeyFromDER(derBlock []byte) (PublicKey, error) { + if len(derBlock) == 0 { + return nil, ErrNoPubKey } - bio := C.BIO_new_mem_buf(unsafe.Pointer(&der_block[0]), - C.int(len(der_block))) + bio := C.BIO_new_mem_buf(unsafe.Pointer(&derBlock[0]), + C.int(len(derBlock))) if bio == nil { - return nil, errors.New("failed creating bio") + return nil, ErrMallocFailure } defer C.BIO_free(bio) key := C.d2i_PUBKEY_bio(bio, nil) if key == nil { - return nil, errors.New("failed reading public key der") + return nil, PopError() } p := &pKey{key: key} @@ -521,24 +533,26 @@ func LoadPublicKeyFromDER(der_block []byte) (PublicKey, error) { return p, nil } -// GenerateRSAKey generates a new RSA private key with an exponent of 3. +// GenerateRSAKey generates a new RSA private key with an exponent of 65537. func GenerateRSAKey(bits int) (PrivateKey, error) { - return GenerateRSAKeyWithExponent(bits, 3) + defaultPubExp := 0x10001 + + return GenerateRSAKeyWithExponent(bits, defaultPubExp) } // GenerateRSAKeyWithExponent generates a new RSA private key. func GenerateRSAKeyWithExponent(bits int, exponent int) (PrivateKey, error) { rsa := C.RSA_generate_key(C.int(bits), C.ulong(exponent), nil, nil) if rsa == nil { - return nil, errors.New("failed to generate RSA key") + return nil, ErrMallocFailure } key := C.X_EVP_PKEY_new() if key == nil { - return nil, errors.New("failed to allocate EVP_PKEY") + return nil, ErrMallocFailure } if C.X_EVP_PKEY_assign_charp(key, C.EVP_PKEY_RSA, (*C.char)(unsafe.Pointer(rsa))) != 1 { C.X_EVP_PKEY_free(key) - return nil, errors.New("failed to assign RSA key") + return nil, PopError() } p := &pKey{key: key} runtime.SetFinalizer(p, func(p *pKey) { @@ -559,78 +573,76 @@ const ( // P-521: NIST/SECG curve over a 521 bit prime field Secp521r1 EllipticCurve = C.NID_secp521r1 // SM2: GB/T 32918-2017 - Sm2Curve EllipticCurve = C.NID_sm2 + SM2Curve EllipticCurve = C.NID_sm2 ) // GenerateECKey generates a new elliptic curve private key on the speicified // curve. func GenerateECKey(curve EllipticCurve) (PrivateKey, error) { - // Create context for parameter generation paramCtx := C.EVP_PKEY_CTX_new_id(C.EVP_PKEY_EC, nil) if paramCtx == nil { - return nil, errors.New("failed creating EC parameter generation context") + return nil, PopError() } defer C.EVP_PKEY_CTX_free(paramCtx) - if curve == Sm2Curve { + if curve == SM2Curve { if C.EVP_PKEY_keygen_init(paramCtx) != 1 { - return nil, errors.New("failed initializing EC key generation context") + return nil, PopError() } } else { - // Intialize the parameter generation if int(C.EVP_PKEY_paramgen_init(paramCtx)) != 1 { - return nil, errors.New("failed initializing EC parameter generation context") + return nil, PopError() } } // Set curve in EC parameter generation context if int(C.X_EVP_PKEY_CTX_set_ec_paramgen_curve_nid(paramCtx, C.int(curve))) != 1 { - return nil, errors.New("failed setting curve in EC parameter generation context") + return nil, PopError() } - var privKey *C.EVP_PKEY + var key *C.EVP_PKEY - if curve == Sm2Curve { - if int(C.EVP_PKEY_keygen(paramCtx, &privKey)) != 1 { - return nil, errors.New("failed generating EC private key") + if curve == SM2Curve { + if int(C.EVP_PKEY_keygen(paramCtx, &key)) != 1 { + return nil, PopError() } } else { // Create parameter object var params *C.EVP_PKEY if int(C.EVP_PKEY_paramgen(paramCtx, ¶ms)) != 1 { - return nil, errors.New("failed creating EC key generation parameters") + return nil, PopError() } defer C.EVP_PKEY_free(params) // Create context for the key generation keyCtx := C.EVP_PKEY_CTX_new(params, nil) if keyCtx == nil { - return nil, errors.New("failed creating EC key generation context") + return nil, PopError() } defer C.EVP_PKEY_CTX_free(keyCtx) if int(C.EVP_PKEY_keygen_init(keyCtx)) != 1 { - return nil, errors.New("failed initializing EC key generation context") + return nil, PopError() } - if int(C.EVP_PKEY_keygen(keyCtx, &privKey)) != 1 { - return nil, errors.New("failed generating EC private key") + if int(C.EVP_PKEY_keygen(keyCtx, &key)) != 1 { + return nil, PopError() } } - p := &pKey{key: privKey} - runtime.SetFinalizer(p, func(p *pKey) { + privKey := &pKey{key: key} + runtime.SetFinalizer(privKey, func(p *pKey) { C.X_EVP_PKEY_free(p.key) }) - if curve == Sm2Curve { - if C.EVP_PKEY_set_alias_type(p.key, C.EVP_PKEY_SM2) != 1 { - return nil, errors.New("failed set alias type") + if curve == SM2Curve { + if C.EVP_PKEY_set_alias_type(privKey.key, C.EVP_PKEY_SM2) != 1 { + return nil, PopError() } } - return p, nil + return privKey, nil } // GenerateED25519Key generates a Ed25519 key @@ -638,17 +650,17 @@ func GenerateED25519Key() (PrivateKey, error) { // Key context keyCtx := C.EVP_PKEY_CTX_new_id(C.X_EVP_PKEY_ED25519, nil) if keyCtx == nil { - return nil, errors.New("failed creating EC parameter generation context") + return nil, PopError() } defer C.EVP_PKEY_CTX_free(keyCtx) // Generate the key var privKey *C.EVP_PKEY if int(C.EVP_PKEY_keygen_init(keyCtx)) != 1 { - return nil, errors.New("failed initializing ED25519 key generation context") + return nil, PopError() } if int(C.EVP_PKEY_keygen(keyCtx, &privKey)) != 1 { - return nil, errors.New("failed generating ED25519 private key") + return nil, PopError() } p := &pKey{key: privKey} diff --git a/crypto/key_test.go b/crypto/key_test.go index aeaac2b..a872ff0 100644 --- a/crypto/key_test.go +++ b/crypto/key_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package crypto +package crypto_test import ( "bytes" @@ -22,12 +22,14 @@ import ( "crypto/x509" "encoding/hex" pem_pkg "encoding/pem" - "io/ioutil" + "os" "testing" + + "github.com/tongsuo-project/tongsuo-go-sdk/crypto" ) -var ( - certBytes = []byte(`-----BEGIN CERTIFICATE----- +const ( + certBytes = `-----BEGIN CERTIFICATE----- MIIExjCCAy6gAwIBAgIRAMqZUO0eR6sVZ3A8iG8bJK8wDQYJKoZIhvcNAQELBQAw ezEeMBwGA1UEChMVbWtjZXJ0IGRldmVsb3BtZW50IENBMSgwJgYDVQQLDB90b21z YXd5ZXJAQi1GRzc5TUw3SC0wNDQ4LmxvY2FsMS8wLQYDVQQDDCZta2NlcnQgdG9t @@ -55,8 +57,8 @@ UjxAjLXvWmij6ilpMADnLQA0SH6s+9E2Aa5LTpEMDqXORcu+sq5/m3RuDtVxuYdU HNnVAmIdTLKC9CWnRfDxH8zPgIr/L8Yhdw92YST8hNqGQHeR0qoBcKYMHkpH6Ay4 yuKERO5LaAmjoXJW3n5Zal6jogf3wpiV1o4= -----END CERTIFICATE----- -`) - keyBytes = []byte(`-----BEGIN RSA PRIVATE KEY----- +` + keyBytes = `-----BEGIN RSA PRIVATE KEY----- MIIG5AIBAAKCAYEAwoj05U9j+Y8xGs8xV9n0+6U6KJSS6eMn9xL4CW9qoxbpmNXw ZNqujLWH0LrceF34X6vqG3U6mrhQKt2g4ywWGOmxQk9OuJqwHrVFPZyheiS4zxwU D/CnsZ5cq1aOKH/PamWnaafNylBdj33o3CQXZBRkiYNandJGPjQsVjfI6I08y2RO @@ -95,14 +97,14 @@ JBfFkFXDvcbaYmpcOVCS1susPrPr8rgIm+vK6X+UoWE1RcCMGQMKObQIDGpE4IWe TDoukqQ8peoffk6mtiCnph9Cl2uqAgmmX+GyunEMIdF/ySG0CCcfz180GsQCucax +AxW2R7NJMAHvfeaYoLtSMYEVTS8sSpuIbRTfGuxbmMOD8a03gU6AA== -----END RSA PRIVATE KEY----- -`) - prime256v1KeyBytes = []byte(`-----BEGIN EC PRIVATE KEY----- +` + prime256v1KeyBytes = `-----BEGIN EC PRIVATE KEY----- MHcCAQEEIB/XL0zZSsAu+IQF1AI/nRneabb2S126WFlvvhzmYr1KoAoGCCqGSM49 AwEHoUQDQgAESSFGWwF6W1hoatKGPPorh4+ipyk0FqpiWdiH+4jIiU39qtOeZGSh 1QgSbzfdHxvoYI0FXM+mqE7wec0kIvrrHw== -----END EC PRIVATE KEY----- -`) - prime256v1CertBytes = []byte(`-----BEGIN CERTIFICATE----- +` + prime256v1CertBytes = `-----BEGIN CERTIFICATE----- MIIChTCCAiqgAwIBAgIJAOQII2LQl4uxMAoGCCqGSM49BAMCMIGcMQswCQYDVQQG EwJVUzEPMA0GA1UECAwGS2Fuc2FzMRAwDgYDVQQHDAdOb3doZXJlMR8wHQYDVQQK DBZGYWtlIENlcnRpZmljYXRlcywgSW5jMUkwRwYDVQQDDEBhMWJkZDVmZjg5ZjQy @@ -118,14 +120,14 @@ FhlGM1wzvusyGrm26Vrbqm4wDwYDVR0TAQH/BAUwAwEB/zAKBggqhkjOPQQDAgNJ ADBGAiEA6PWNjm4B6zs3Wcha9qyDdfo1ILhHfk9rZEAGrnfyc2UCIQD1IDVJUkI4 J/QVoOtP5DOdRPs/3XFy0Bk0qH+Uj5D7LQ== -----END CERTIFICATE----- -`) - sm2KeyBytes = []byte(`-----BEGIN EC PRIVATE KEY----- +` + sm2KeyBytes = `-----BEGIN EC PRIVATE KEY----- MHcCAQEEIJ0I4yR5ezlVWygUi7+NNipNJSBqUjaCopitIJMU1nlSoAoGCCqBHM9V AYItoUQDQgAEMbiGjBxkDrC1rwuVlIC/6fbGdnaKxj2/Lkv9EcOLKv3WFuFi1eae UvQSkNcRMdaAixpM+RKQ+Cp6Z3szJUr0jQ== -----END EC PRIVATE KEY----- -`) - sm2CertBytes = []byte(`-----BEGIN CERTIFICATE----- +` + sm2CertBytes = `-----BEGIN CERTIFICATE----- MIIDKjCCAs+gAwIBAgIQILmubp7njmhGt3wZLFwx6jAKBggqgRzPVQGDdTBtMQsw CQYDVQQGEwJDTjELMAkGA1UECAwCSFoxDDAKBgNVBAoMA2FsaTEMMAoGA1UECwwD YW50MRcwFQYDVQQDDA53d3cubWlkZGxlLmNvbTEcMBoGCSqGSIb3DQEJARYNdGVz @@ -144,8 +146,8 @@ A1UdJQQMMAoGCCsGAQUFBwMBMCMGA1UdEQQcMBqCCmFsaXBheS5jb22CDCouYWxp cGF5LmNvbTAKBggqgRzPVQGDdQNJADBGAiEAmuMCuZKaF3zVYc1T6DGGi0+hmMuZ jpH7uznwqix7GJsCIQCOjB/iG+WxOvUz//t//Ru1QnVivDaCEQXkW2dXyX+fWg== -----END CERTIFICATE----- -`) - ed25519CertBytes = []byte(`-----BEGIN CERTIFICATE----- +` + ed25519CertBytes = `-----BEGIN CERTIFICATE----- MIIBIzCB1gIUd0UUPX+qHrSKSVN9V/A3F1Eeti4wBQYDK2VwMDYxCzAJBgNVBAYT AnVzMQ0wCwYDVQQKDARDU0NPMRgwFgYDVQQDDA9lZDI1NTE5X3Jvb3RfY2EwHhcN MTgwODE3MDMzNzQ4WhcNMjgwODE0MDMzNzQ4WjAzMQswCQYDVQQGEwJ1czENMAsG @@ -154,25 +156,29 @@ zzlBcpjdbvzV0BRoaSiJKxbU6GnFeAELA0cHWR0wBQYDK2VwA0EAbfUJ7L7v3GDq Gv7R90wQ/OKAc+o0q9eOrD6KRYDBhvlnMKqTMRVucnHXfrd5Rhmf4yHTvFTOhwmO t/hpmISAAA== -----END CERTIFICATE----- -`) - ed25519KeyBytes = []byte(`-----BEGIN PRIVATE KEY----- +` + ed25519KeyBytes = `-----BEGIN PRIVATE KEY----- MC4CAQAwBQYDK2VwBCIEIL3QVwyuusKuLgZwZn356UHk9u1REGHbNTLtFMPKNQSb -----END PRIVATE KEY----- -`) +` ) func TestMarshal(t *testing.T) { - key, err := LoadPrivateKeyFromPEM(keyBytes) + t.Parallel() + + _, err := crypto.LoadPrivateKeyFromPEM([]byte(keyBytes)) if err != nil { t.Error(err) } - cert, err := LoadCertificateFromPEM(certBytes) + + cert, err := crypto.LoadCertificateFromPEM([]byte(certBytes)) if err != nil { t.Error(err) } - privateBlock, _ := pem_pkg.Decode(keyBytes) - key, err = LoadPrivateKeyFromDER(privateBlock.Bytes) + privateBlock, _ := pem_pkg.Decode([]byte(keyBytes)) + + key, err := crypto.LoadPrivateKeyFromDER(privateBlock.Bytes) if err != nil { t.Error(err) } @@ -181,9 +187,18 @@ func TestMarshal(t *testing.T) { if err != nil { t.Error(err) } - if !bytes.Equal(pem, certBytes) { - ioutil.WriteFile("generated", pem, 0644) - ioutil.WriteFile("hardcoded", certBytes, 0644) + + if !bytes.Equal(pem, []byte(certBytes)) { + err := os.WriteFile("generated", pem, 0o600) + if err != nil { + t.Error(err) + } + + err = os.WriteFile("hardcoded", []byte(certBytes), 0o600) + if err != nil { + t.Error(err) + } + t.Error("invalid cert pem bytes") } @@ -191,42 +206,63 @@ func TestMarshal(t *testing.T) { if err != nil { t.Error(err) } - if !bytes.Equal(pem, keyBytes) { - ioutil.WriteFile("generated", pem, 0644) - ioutil.WriteFile("hardcoded", keyBytes, 0644) + + if !bytes.Equal(pem, []byte(keyBytes)) { + err := os.WriteFile("generated", pem, 0o600) + if err != nil { + t.Error(err) + } + + err = os.WriteFile("hardcoded", []byte(keyBytes), 0o600) + if err != nil { + t.Error(err) + } + t.Error("invalid private key pem bytes") } - tls_cert, err := tls.X509KeyPair(certBytes, keyBytes) + + tlsCert, err := tls.X509KeyPair([]byte(certBytes), []byte(keyBytes)) if err != nil { t.Error(err) } - tls_key, ok := tls_cert.PrivateKey.(*rsa.PrivateKey) + + tlsKey, ok := tlsCert.PrivateKey.(*rsa.PrivateKey) if !ok { t.Error("FASDFASDF") } - _ = tls_key der, err := key.MarshalPKCS1PrivateKeyDER() if err != nil { t.Error(err) } - tls_der := x509.MarshalPKCS1PrivateKey(tls_key) - if !bytes.Equal(der, tls_der) { + + tlsDer := x509.MarshalPKCS1PrivateKey(tlsKey) + if !bytes.Equal(der, tlsDer) { t.Errorf("invalid private key der bytes: %s\n v.s. %s\n", - hex.Dump(der), hex.Dump(tls_der)) + hex.Dump(der), hex.Dump(tlsDer)) } der, err = key.MarshalPKIXPublicKeyDER() if err != nil { t.Error(err) } - tls_der, err = x509.MarshalPKIXPublicKey(&tls_key.PublicKey) + + tlsDer, err = x509.MarshalPKIXPublicKey(&tlsKey.PublicKey) if err != nil { t.Error(err) } - if !bytes.Equal(der, tls_der) { - ioutil.WriteFile("generated", []byte(hex.Dump(der)), 0644) - ioutil.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) + + if !bytes.Equal(der, tlsDer) { + err := os.WriteFile("generated", []byte(hex.Dump(der)), 0o600) + if err != nil { + t.Error(err) + } + + err = os.WriteFile("hardcoded", []byte(hex.Dump(tlsDer)), 0o600) + if err != nil { + t.Error(err) + } + t.Error("invalid public key der bytes") } @@ -234,75 +270,110 @@ func TestMarshal(t *testing.T) { if err != nil { t.Error(err) } - tls_pem := pem_pkg.EncodeToMemory(&pem_pkg.Block{ - Type: "PUBLIC KEY", Bytes: tls_der}) - if !bytes.Equal(pem, tls_pem) { - ioutil.WriteFile("generated", pem, 0644) - ioutil.WriteFile("hardcoded", tls_pem, 0644) + + tlsPem := pem_pkg.EncodeToMemory(&pem_pkg.Block{ + Type: "PUBLIC KEY", Headers: nil, Bytes: tlsDer, + }) + if !bytes.Equal(pem, tlsPem) { + err := os.WriteFile("generated", pem, 0o600) + if err != nil { + t.Error(err) + } + + err = os.WriteFile("hardcoded", tlsPem, 0o600) + if err != nil { + t.Error(err) + } + t.Error("invalid public key pem bytes") } - loaded_pubkey_from_pem, err := LoadPublicKeyFromPEM(pem) + pubkeyFromPem, err := crypto.LoadPublicKeyFromPEM(pem) if err != nil { t.Error(err) } - loaded_pubkey_from_der, err := LoadPublicKeyFromDER(der) + pubkeyFromDer, err := crypto.LoadPublicKeyFromDER(der) if err != nil { t.Error(err) } - new_der_from_pem, err := loaded_pubkey_from_pem.MarshalPKIXPublicKeyDER() + newDerFromPem, err := pubkeyFromPem.MarshalPKIXPublicKeyDER() if err != nil { t.Error(err) } - new_der_from_der, err := loaded_pubkey_from_der.MarshalPKIXPublicKeyDER() + newDerFromDer, err := pubkeyFromDer.MarshalPKIXPublicKeyDER() if err != nil { t.Error(err) } - if !bytes.Equal(new_der_from_der, tls_der) { - ioutil.WriteFile("generated", []byte(hex.Dump(new_der_from_der)), 0644) - ioutil.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) + if !bytes.Equal(newDerFromDer, tlsDer) { + err := os.WriteFile("generated", []byte(hex.Dump(newDerFromDer)), 0o600) + if err != nil { + t.Error(err) + } + + err = os.WriteFile("hardcoded", []byte(hex.Dump(tlsDer)), 0o600) + if err != nil { + t.Error(err) + } + t.Error("invalid public key der bytes") } - if !bytes.Equal(new_der_from_pem, tls_der) { - ioutil.WriteFile("generated", []byte(hex.Dump(new_der_from_pem)), 0644) - ioutil.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) + if !bytes.Equal(newDerFromPem, tlsDer) { + err := os.WriteFile("generated", []byte(hex.Dump(newDerFromPem)), 0o600) + if err != nil { + t.Error(err) + } + + err = os.WriteFile("hardcoded", []byte(hex.Dump(tlsDer)), 0o600) + if err != nil { + t.Error(err) + } + t.Error("invalid public key der bytes") } } func TestGenerate(t *testing.T) { - key, err := GenerateRSAKey(2048) + t.Parallel() + + key, err := crypto.GenerateRSAKey(2048) if err != nil { t.Error(err) } + _, err = key.MarshalPKIXPublicKeyPEM() if err != nil { t.Error(err) } + _, err = key.MarshalPKCS1PrivateKeyPEM() if err != nil { t.Error(err) } - _, err = GenerateRSAKeyWithExponent(1024, 65537) + + _, err = crypto.GenerateRSAKeyWithExponent(1024, 65537) if err != nil { t.Error(err) } } func TestGenerateEC(t *testing.T) { - key, err := GenerateECKey(Prime256v1) + t.Parallel() + + key, err := crypto.GenerateECKey(crypto.Prime256v1) if err != nil { t.Error(err) } + _, err = key.MarshalPKIXPublicKeyPEM() if err != nil { t.Error(err) } + _, err = key.MarshalPKCS1PrivateKeyPEM() if err != nil { t.Error(err) @@ -310,18 +381,22 @@ func TestGenerateEC(t *testing.T) { } func TestGenerateEd25519(t *testing.T) { - if !ed25519_support { + t.Parallel() + + if !crypto.SupportEd25519() { t.SkipNow() } - key, err := GenerateED25519Key() + key, err := crypto.GenerateED25519Key() if err != nil { t.Error(err) } + _, err = key.MarshalPKIXPublicKeyPEM() if err != nil { t.Error(err) } + _, err = key.MarshalPKCS1PrivateKeyPEM() if err != nil { t.Error(err) @@ -329,17 +404,22 @@ func TestGenerateEd25519(t *testing.T) { } func TestSign(t *testing.T) { - key, _ := GenerateRSAKey(1024) + t.Parallel() + + key, _ := crypto.GenerateRSAKey(1024) data := []byte("the quick brown fox jumps over the lazy dog") - _, err := key.SignPKCS1v15(SHA1_Method, data) + + _, err := key.SignPKCS1v15(crypto.SHA1Method(), data) if err != nil { t.Error(err) } - _, err = key.SignPKCS1v15(SHA256_Method, data) + + _, err = key.SignPKCS1v15(crypto.SHA256Method(), data) if err != nil { t.Error(err) } - _, err = key.SignPKCS1v15(SHA512_Method, data) + + _, err = key.SignPKCS1v15(crypto.SHA512Method(), data) if err != nil { t.Error(err) } @@ -348,19 +428,22 @@ func TestSign(t *testing.T) { func TestSignEC(t *testing.T) { t.Parallel() - key, err := GenerateECKey(Prime256v1) + key, err := crypto.GenerateECKey(crypto.Prime256v1) if err != nil { t.Error(err) } + data := []byte("the quick brown fox jumps over the lazy dog") t.Run("sha1", func(t *testing.T) { t.Parallel() - sig, err := key.SignPKCS1v15(SHA1_Method, data) + + sig, err := key.SignPKCS1v15(crypto.SHA1Method(), data) if err != nil { t.Error(err) } - err = key.VerifyPKCS1v15(SHA1_Method, data, sig) + + err = key.VerifyPKCS1v15(crypto.SHA1Method(), data, sig) if err != nil { t.Error(err) } @@ -368,11 +451,13 @@ func TestSignEC(t *testing.T) { t.Run("sha256", func(t *testing.T) { t.Parallel() - sig, err := key.SignPKCS1v15(SHA256_Method, data) + + sig, err := key.SignPKCS1v15(crypto.SHA256Method(), data) if err != nil { t.Error(err) } - err = key.VerifyPKCS1v15(SHA256_Method, data, sig) + + err = key.VerifyPKCS1v15(crypto.SHA256Method(), data, sig) if err != nil { t.Error(err) } @@ -380,11 +465,13 @@ func TestSignEC(t *testing.T) { t.Run("sha512", func(t *testing.T) { t.Parallel() - sig, err := key.SignPKCS1v15(SHA512_Method, data) + + sig, err := key.SignPKCS1v15(crypto.SHA512Method(), data) if err != nil { t.Error(err) } - err = key.VerifyPKCS1v15(SHA512_Method, data, sig) + + err = key.VerifyPKCS1v15(crypto.SHA512Method(), data, sig) if err != nil { t.Error(err) } @@ -394,19 +481,22 @@ func TestSignEC(t *testing.T) { func TestSignSM2(t *testing.T) { t.Parallel() - key, err := GenerateECKey(Sm2Curve) + key, err := crypto.GenerateECKey(crypto.SM2Curve) if err != nil { t.Error(err) } + data := []byte("the quick brown fox jumps over the lazy dog") t.Run("sm2", func(t *testing.T) { t.Parallel() - sig, err := key.SignPKCS1v15(SM3_Method, data) + + sig, err := key.SignPKCS1v15(crypto.SM3Method(), data) if err != nil { t.Error(err) } - err = key.VerifyPKCS1v15(SM3_Method, data, sig) + + err = key.VerifyPKCS1v15(crypto.SM3Method(), data, sig) if err != nil { t.Error(err) } @@ -414,20 +504,22 @@ func TestSignSM2(t *testing.T) { } func TestSignED25519(t *testing.T) { - if !ed25519_support { + t.Parallel() + + if !crypto.SupportEd25519() { t.SkipNow() } - t.Parallel() - - key, err := GenerateED25519Key() + key, err := crypto.GenerateED25519Key() if err != nil { t.Error(err) } + data := []byte("the quick brown fox jumps over the lazy dog") t.Run("new", func(t *testing.T) { t.Parallel() + sig, err := key.SignPKCS1v15(nil, data) if err != nil { t.Error(err) @@ -437,22 +529,25 @@ func TestSignED25519(t *testing.T) { if err != nil { t.Error(err) } - }) } func TestMarshalEC(t *testing.T) { - key, err := LoadPrivateKeyFromPEM(prime256v1KeyBytes) + t.Parallel() + + _, err := crypto.LoadPrivateKeyFromPEM([]byte(prime256v1KeyBytes)) if err != nil { t.Error(err) } - cert, err := LoadCertificateFromPEM(prime256v1CertBytes) + + cert, err := crypto.LoadCertificateFromPEM([]byte(prime256v1CertBytes)) if err != nil { t.Error(err) } - privateBlock, _ := pem_pkg.Decode(prime256v1KeyBytes) - key, err = LoadPrivateKeyFromDER(privateBlock.Bytes) + privateBlock, _ := pem_pkg.Decode([]byte(prime256v1KeyBytes)) + + key, err := crypto.LoadPrivateKeyFromDER(privateBlock.Bytes) if err != nil { t.Error(err) } @@ -461,9 +556,18 @@ func TestMarshalEC(t *testing.T) { if err != nil { t.Error(err) } - if !bytes.Equal(pem, prime256v1CertBytes) { - ioutil.WriteFile("generated", pem, 0644) - ioutil.WriteFile("hardcoded", prime256v1CertBytes, 0644) + + if !bytes.Equal(pem, []byte(prime256v1CertBytes)) { + err := os.WriteFile("generated", pem, 0o600) + if err != nil { + t.Error(err) + } + + err = os.WriteFile("hardcoded", []byte(prime256v1CertBytes), 0o600) + if err != nil { + t.Error(err) + } + t.Error("invalid cert pem bytes") } @@ -471,45 +575,69 @@ func TestMarshalEC(t *testing.T) { if err != nil { t.Error(err) } - if !bytes.Equal(pem, prime256v1KeyBytes) { - ioutil.WriteFile("generated", pem, 0644) - ioutil.WriteFile("hardcoded", prime256v1KeyBytes, 0644) + + if !bytes.Equal(pem, []byte(prime256v1KeyBytes)) { + err := os.WriteFile("generated", pem, 0o600) + if err != nil { + t.Error(err) + } + + err = os.WriteFile("hardcoded", []byte(prime256v1KeyBytes), 0o600) + if err != nil { + t.Error(err) + } + t.Error("invalid private key pem bytes") } - tls_cert, err := tls.X509KeyPair(prime256v1CertBytes, prime256v1KeyBytes) + + tlsCert, err := tls.X509KeyPair([]byte(prime256v1CertBytes), []byte(prime256v1KeyBytes)) if err != nil { t.Error(err) } - tls_key, ok := tls_cert.PrivateKey.(*ecdsa.PrivateKey) + + tlsKey, ok := tlsCert.PrivateKey.(*ecdsa.PrivateKey) if !ok { t.Error("FASDFASDF") } - _ = tls_key + + _ = tlsKey der, err := key.MarshalPKCS1PrivateKeyDER() if err != nil { t.Error(err) } - tls_der, err := x509.MarshalECPrivateKey(tls_key) + + tlsDer, err := x509.MarshalECPrivateKey(tlsKey) if err != nil { t.Error(err) } - if !bytes.Equal(der, tls_der) { + + if !bytes.Equal(der, tlsDer) { t.Errorf("invalid private key der bytes: %s\n v.s. %s\n", - hex.Dump(der), hex.Dump(tls_der)) + hex.Dump(der), hex.Dump(tlsDer)) } der, err = key.MarshalPKIXPublicKeyDER() if err != nil { t.Error(err) } - tls_der, err = x509.MarshalPKIXPublicKey(&tls_key.PublicKey) + + tlsDer, err = x509.MarshalPKIXPublicKey(&tlsKey.PublicKey) if err != nil { t.Error(err) } - if !bytes.Equal(der, tls_der) { - ioutil.WriteFile("generated", []byte(hex.Dump(der)), 0644) - ioutil.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) + + if !bytes.Equal(der, tlsDer) { + err := os.WriteFile("generated", []byte(hex.Dump(der)), 0o600) + if err != nil { + t.Error(err) + } + + err = os.WriteFile("hardcoded", []byte(hex.Dump(tlsDer)), 0o600) + if err != nil { + t.Error(err) + } + t.Error("invalid public key der bytes") } @@ -517,59 +645,89 @@ func TestMarshalEC(t *testing.T) { if err != nil { t.Error(err) } - tls_pem := pem_pkg.EncodeToMemory(&pem_pkg.Block{ - Type: "PUBLIC KEY", Bytes: tls_der}) - if !bytes.Equal(pem, tls_pem) { - ioutil.WriteFile("generated", pem, 0644) - ioutil.WriteFile("hardcoded", tls_pem, 0644) + + tlsPem := pem_pkg.EncodeToMemory(&pem_pkg.Block{ + Type: "PUBLIC KEY", Headers: nil, Bytes: tlsDer, + }) + if !bytes.Equal(pem, tlsPem) { + err := os.WriteFile("generated", pem, 0o600) + if err != nil { + t.Error(err) + } + + err = os.WriteFile("hardcoded", tlsPem, 0o600) + if err != nil { + t.Error(err) + } + t.Error("invalid public key pem bytes") } - loaded_pubkey_from_pem, err := LoadPublicKeyFromPEM(pem) + pubkeyFromPem, err := crypto.LoadPublicKeyFromPEM(pem) if err != nil { t.Error(err) } - loaded_pubkey_from_der, err := LoadPublicKeyFromDER(der) + pubkeyFromDer, err := crypto.LoadPublicKeyFromDER(der) if err != nil { t.Error(err) } - new_der_from_pem, err := loaded_pubkey_from_pem.MarshalPKIXPublicKeyDER() + newDerFromPem, err := pubkeyFromPem.MarshalPKIXPublicKeyDER() if err != nil { t.Error(err) } - new_der_from_der, err := loaded_pubkey_from_der.MarshalPKIXPublicKeyDER() + newDerFromDer, err := pubkeyFromDer.MarshalPKIXPublicKeyDER() if err != nil { t.Error(err) } - if !bytes.Equal(new_der_from_der, tls_der) { - ioutil.WriteFile("generated", []byte(hex.Dump(new_der_from_der)), 0644) - ioutil.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) + if !bytes.Equal(newDerFromDer, tlsDer) { + err := os.WriteFile("generated", []byte(hex.Dump(newDerFromDer)), 0o600) + if err != nil { + t.Error(err) + } + + err = os.WriteFile("hardcoded", []byte(hex.Dump(tlsDer)), 0o600) + if err != nil { + t.Error(err) + } + t.Error("invalid public key der bytes") } - if !bytes.Equal(new_der_from_pem, tls_der) { - ioutil.WriteFile("generated", []byte(hex.Dump(new_der_from_pem)), 0644) - ioutil.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) + if !bytes.Equal(newDerFromPem, tlsDer) { + err := os.WriteFile("generated", []byte(hex.Dump(newDerFromPem)), 0o600) + if err != nil { + t.Error(err) + } + + err = os.WriteFile("hardcoded", []byte(hex.Dump(tlsDer)), 0o600) + if err != nil { + t.Error(err) + } + t.Error("invalid public key der bytes") } } func TestMarshalSM2(t *testing.T) { - key, err := LoadPrivateKeyFromPEM(sm2KeyBytes) + t.Parallel() + + _, err := crypto.LoadPrivateKeyFromPEM([]byte(sm2KeyBytes)) if err != nil { t.Error(err) } - cert, err := LoadCertificateFromPEM(sm2CertBytes) + + cert, err := crypto.LoadCertificateFromPEM([]byte(sm2CertBytes)) if err != nil { t.Error(err) } - privateBlock, _ := pem_pkg.Decode(sm2KeyBytes) - key, err = LoadPrivateKeyFromDER(privateBlock.Bytes) + privateBlock, _ := pem_pkg.Decode([]byte(sm2KeyBytes)) + + key, err := crypto.LoadPrivateKeyFromDER(privateBlock.Bytes) if err != nil { t.Error(err) } @@ -578,9 +736,18 @@ func TestMarshalSM2(t *testing.T) { if err != nil { t.Error(err) } - if !bytes.Equal(pem, sm2CertBytes) { - ioutil.WriteFile("generated", pem, 0644) - ioutil.WriteFile("hardcoded", sm2CertBytes, 0644) + + if !bytes.Equal(pem, []byte(sm2CertBytes)) { + err := os.WriteFile("generated", pem, 0o600) + if err != nil { + t.Error(err) + } + + err = os.WriteFile("hardcoded", []byte(sm2CertBytes), 0o600) + if err != nil { + t.Error(err) + } + t.Error("invalid cert pem bytes") } @@ -588,29 +755,42 @@ func TestMarshalSM2(t *testing.T) { if err != nil { t.Error(err) } - if !bytes.Equal(pem, sm2KeyBytes) { - ioutil.WriteFile("generated", pem, 0644) - ioutil.WriteFile("hardcoded", sm2KeyBytes, 0644) + + if !bytes.Equal(pem, []byte(sm2KeyBytes)) { + err := os.WriteFile("generated", pem, 0o600) + if err != nil { + t.Error(err) + } + + err = os.WriteFile("hardcoded", []byte(sm2KeyBytes), 0o600) + if err != nil { + t.Error(err) + } + t.Error("invalid private key pem bytes") } } func TestMarshalEd25519(t *testing.T) { - if !ed25519_support { + t.Parallel() + + if !crypto.SupportEd25519() { t.SkipNow() } - key, err := LoadPrivateKeyFromPEM(ed25519KeyBytes) + _, err := crypto.LoadPrivateKeyFromPEM([]byte(ed25519KeyBytes)) if err != nil { t.Error(err) } - cert, err := LoadCertificateFromPEM(ed25519CertBytes) + + cert, err := crypto.LoadCertificateFromPEM([]byte(ed25519CertBytes)) if err != nil { t.Error(err) } - privateBlock, _ := pem_pkg.Decode(ed25519KeyBytes) - key, err = LoadPrivateKeyFromDER(privateBlock.Bytes) + privateBlock, _ := pem_pkg.Decode([]byte(ed25519KeyBytes)) + + key, err := crypto.LoadPrivateKeyFromDER(privateBlock.Bytes) if err != nil { t.Error(err) } @@ -619,23 +799,32 @@ func TestMarshalEd25519(t *testing.T) { if err != nil { t.Error(err) } - if !bytes.Equal(pem, ed25519CertBytes) { - ioutil.WriteFile("generated", pem, 0644) - ioutil.WriteFile("hardcoded", ed25519CertBytes, 0644) + + if !bytes.Equal(pem, []byte(ed25519CertBytes)) { + err := os.WriteFile("generated", pem, 0o600) + if err != nil { + t.Error(err) + } + + err = os.WriteFile("hardcoded", []byte(ed25519CertBytes), 0o600) + if err != nil { + t.Error(err) + } + t.Error("invalid cert pem bytes") } - pem, err = key.MarshalPKCS1PrivateKeyPEM() + _, err = key.MarshalPKCS1PrivateKeyPEM() if err != nil { t.Error(err) } - der, err := key.MarshalPKCS1PrivateKeyDER() + _, err = key.MarshalPKCS1PrivateKeyDER() if err != nil { t.Error(err) } - der, err = key.MarshalPKIXPublicKeyDER() + der, err := key.MarshalPKIXPublicKeyDER() if err != nil { t.Error(err) } @@ -645,22 +834,22 @@ func TestMarshalEd25519(t *testing.T) { t.Error(err) } - loaded_pubkey_from_pem, err := LoadPublicKeyFromPEM(pem) + pubkeyFromPem, err := crypto.LoadPublicKeyFromPEM(pem) if err != nil { t.Error(err) } - loaded_pubkey_from_der, err := LoadPublicKeyFromDER(der) + pubkeyFromDer, err := crypto.LoadPublicKeyFromDER(der) if err != nil { t.Error(err) } - _, err = loaded_pubkey_from_pem.MarshalPKIXPublicKeyDER() + _, err = pubkeyFromPem.MarshalPKIXPublicKeyDER() if err != nil { t.Error(err) } - _, err = loaded_pubkey_from_der.MarshalPKIXPublicKeyDER() + _, err = pubkeyFromDer.MarshalPKIXPublicKeyDER() if err != nil { t.Error(err) } diff --git a/crypto/mapping.go b/crypto/mapping.go index 778cbfd..a6de84f 100644 --- a/crypto/mapping.go +++ b/crypto/mapping.go @@ -14,14 +14,14 @@ package crypto +// #include +import "C" + import ( "sync" "unsafe" ) -// #include -import "C" - type mapping struct { lock sync.Mutex values map[token]unsafe.Pointer diff --git a/crypto/md5/md5.go b/crypto/md5/md5.go index c8e62b1..342bbd1 100644 --- a/crypto/md5/md5.go +++ b/crypto/md5/md5.go @@ -15,14 +15,10 @@ package md5 // #include "../shim.h" -// #cgo linux LDFLAGS: -lcrypto -// #cgo darwin LDFLAGS: -lcrypto -// #cgo windows CFLAGS: -DWIN32_LEAN_AND_MEAN -// #cgo windows pkg-config: libcrypto import "C" import ( - "errors" + "fmt" "hash" "runtime" "unsafe" @@ -31,8 +27,8 @@ import ( ) const ( - MD5_DIGEST_LENGTH = 16 - MD5_CBLOCK = 64 + MDSize = 16 + md5Cblock = 64 ) var _ hash.Hash = new(MD5) @@ -45,30 +41,34 @@ type MD5 struct { func New() (*MD5, error) { return NewWithEngine(nil) } func NewWithEngine(e *crypto.Engine) (*MD5, error) { - h, err := newMD5WithEngine(e) + hash, err := newMD5WithEngine(e) if err != nil { return nil, err } - h.Reset() - return h, nil + + hash.Reset() + + return hash, nil } func newMD5WithEngine(e *crypto.Engine) (*MD5, error) { - hash := &MD5{engine: e} + hash := &MD5{ctx: nil, engine: e} hash.ctx = C.X_EVP_MD_CTX_new() if hash.ctx == nil { - return nil, errors.New("openssl: md5: unable to allocate ctx") + return nil, fmt.Errorf("failed to create md ctx: %w", crypto.ErrMallocFailure) } + runtime.SetFinalizer(hash, func(hash *MD5) { hash.Close() }) + return hash, nil } func (s *MD5) BlockSize() int { - return MD5_CBLOCK + return md5Cblock } func (s *MD5) Size() int { - return MD5_DIGEST_LENGTH + return MDSize } func (s *MD5) Close() { @@ -82,14 +82,16 @@ func (s *MD5) Reset() { C.X_EVP_DigestInit_ex(s.ctx, C.X_EVP_md5(), (*C.ENGINE)(s.engine.Engine())) } -func (s *MD5) Write(p []byte) (n int, err error) { - if len(p) == 0 { +func (s *MD5) Write(data []byte) (int, error) { + if len(data) == 0 { return 0, nil } - if 1 != C.X_EVP_DigestUpdate(s.ctx, unsafe.Pointer(&p[0]), C.size_t(len(p))) { - return 0, errors.New("openssl: md5: cannot update digest") + + if C.X_EVP_DigestUpdate(s.ctx, unsafe.Pointer(&data[0]), C.size_t(len(data))) != 1 { + return 0, fmt.Errorf("failed to update digest: %w", crypto.PopError()) } - return len(p), nil + + return len(data), nil } func (s *MD5) Sum(in []byte) []byte { @@ -106,19 +108,19 @@ func (s *MD5) Sum(in []byte) []byte { return append(in, result[:]...) } -func (s *MD5) checkSum() (result [MD5_DIGEST_LENGTH]byte) { +func (s *MD5) checkSum() [MDSize]byte { + var result [MDSize]byte + C.X_EVP_DigestFinal_ex(s.ctx, (*C.uchar)(unsafe.Pointer(&result[0])), nil) + return result } -func Sum(data []byte) (result [MD5_DIGEST_LENGTH]byte) { - C.X_EVP_Digest( - unsafe.Pointer(&data[0]), - C.size_t(len(data)), - (*C.uchar)(unsafe.Pointer(&result[0])), - nil, - C.X_EVP_md5(), - nil, - ) - return +func Sum(data []byte) [MDSize]byte { + var result [MDSize]byte + + C.X_EVP_Digest(unsafe.Pointer(&data[0]), C.size_t(len(data)), (*C.uchar)(unsafe.Pointer(&result[0])), nil, + C.X_EVP_md5(), nil) + + return result } diff --git a/crypto/md5/md5_test.go b/crypto/md5/md5_test.go index 99a47dd..d1dc12d 100644 --- a/crypto/md5/md5_test.go +++ b/crypto/md5/md5_test.go @@ -12,27 +12,32 @@ // See the License for the specific language governing permissions and // limitations under the License. -package md5 +package md5_test import ( "crypto/md5" "crypto/rand" "io" "testing" + + tsMD5 "github.com/tongsuo-project/tongsuo-go-sdk/crypto/md5" ) func TestMD5(t *testing.T) { + t.Parallel() + for i := 0; i < 100; i++ { buf := make([]byte, 10*1024-i) if _, err := io.ReadFull(rand.Reader, buf); err != nil { t.Fatal(err) } - var got, expected [MD5_DIGEST_LENGTH]byte + var got, expected [tsMD5.MDSize]byte s := md5.Sum(buf) - got = Sum(buf) - copy(expected[:], s[:MD5_DIGEST_LENGTH]) + got = tsMD5.Sum(buf) + + copy(expected[:], s[:tsMD5.MDSize]) if expected != got { t.Fatalf("exp:%x got:%x", expected, got) @@ -41,15 +46,19 @@ func TestMD5(t *testing.T) { } func TestMD5Writer(t *testing.T) { - ohash, err := New() + t.Parallel() + + ohash, err := tsMD5.New() if err != nil { t.Fatal(err) } + hash := md5.New() for i := 0; i < 100; i++ { ohash.Reset() hash.Reset() + buf := make([]byte, 10*1024-i) if _, err := io.ReadFull(rand.Reader, buf); err != nil { t.Fatal(err) @@ -58,11 +67,12 @@ func TestMD5Writer(t *testing.T) { if _, err := ohash.Write(buf); err != nil { t.Fatal(err) } + if _, err := hash.Write(buf); err != nil { t.Fatal(err) } - var got, exp [MD5_DIGEST_LENGTH]byte + var got, exp [tsMD5.MDSize]byte hash.Sum(exp[:0]) ohash.Sum(got[:0]) @@ -76,19 +86,23 @@ func TestMD5Writer(t *testing.T) { type md5func func([]byte) func benchmarkMD5(b *testing.B, length int64, fn md5func) { + b.Helper() + buf := make([]byte, length) if _, err := io.ReadFull(rand.Reader, buf); err != nil { b.Fatal(err) } + b.SetBytes(length) b.ResetTimer() + for i := 0; i < b.N; i++ { fn(buf) } } func BenchmarkMD5Large_openssl(b *testing.B) { - benchmarkMD5(b, 1024*1024, func(buf []byte) { Sum(buf) }) + benchmarkMD5(b, 1024*1024, func(buf []byte) { tsMD5.Sum(buf) }) } func BenchmarkMD5Large_stdlib(b *testing.B) { @@ -96,7 +110,7 @@ func BenchmarkMD5Large_stdlib(b *testing.B) { } func BenchmarkMD5Normal_openssl(b *testing.B) { - benchmarkMD5(b, 1024, func(buf []byte) { Sum(buf) }) + benchmarkMD5(b, 1024, func(buf []byte) { tsMD5.Sum(buf) }) } func BenchmarkMD5Normal_stdlib(b *testing.B) { @@ -104,7 +118,7 @@ func BenchmarkMD5Normal_stdlib(b *testing.B) { } func BenchmarkMD5Small_openssl(b *testing.B) { - benchmarkMD5(b, 1, func(buf []byte) { Sum(buf) }) + benchmarkMD5(b, 1, func(buf []byte) { tsMD5.Sum(buf) }) } func BenchmarkMD5Small_stdlib(b *testing.B) { diff --git a/crypto/nid.go b/crypto/nid.go index c4519e9..7af2e05 100644 --- a/crypto/nid.go +++ b/crypto/nid.go @@ -17,195 +17,195 @@ package crypto type NID int const ( - NID_undef NID = 0 - NID_rsadsi NID = 1 - NID_pkcs NID = 2 - NID_md2 NID = 3 - NID_md5 NID = 4 - NID_rc4 NID = 5 - NID_rsaEncryption NID = 6 - NID_md2WithRSAEncryption NID = 7 - NID_md5WithRSAEncryption NID = 8 - NID_pbeWithMD2AndDES_CBC NID = 9 - NID_pbeWithMD5AndDES_CBC NID = 10 - NID_X500 NID = 11 - NID_X509 NID = 12 - NID_commonName NID = 13 - NID_countryName NID = 14 - NID_localityName NID = 15 - NID_stateOrProvinceName NID = 16 - NID_organizationName NID = 17 - NID_organizationalUnitName NID = 18 - NID_rsa NID = 19 - NID_pkcs7 NID = 20 - NID_pkcs7_data NID = 21 - NID_pkcs7_signed NID = 22 - NID_pkcs7_enveloped NID = 23 - NID_pkcs7_signedAndEnveloped NID = 24 - NID_pkcs7_digest NID = 25 - NID_pkcs7_encrypted NID = 26 - NID_pkcs3 NID = 27 - NID_dhKeyAgreement NID = 28 - NID_des_ecb NID = 29 - NID_des_cfb64 NID = 30 - NID_des_cbc NID = 31 - NID_des_ede NID = 32 - NID_des_ede3 NID = 33 - NID_idea_cbc NID = 34 - NID_idea_cfb64 NID = 35 - NID_idea_ecb NID = 36 - NID_rc2_cbc NID = 37 - NID_rc2_ecb NID = 38 - NID_rc2_cfb64 NID = 39 - NID_rc2_ofb64 NID = 40 - NID_sha NID = 41 - NID_shaWithRSAEncryption NID = 42 - NID_des_ede_cbc NID = 43 - NID_des_ede3_cbc NID = 44 - NID_des_ofb64 NID = 45 - NID_idea_ofb64 NID = 46 - NID_pkcs9 NID = 47 - NID_pkcs9_emailAddress NID = 48 - NID_pkcs9_unstructuredName NID = 49 - NID_pkcs9_contentType NID = 50 - NID_pkcs9_messageDigest NID = 51 - NID_pkcs9_signingTime NID = 52 - NID_pkcs9_countersignature NID = 53 - NID_pkcs9_challengePassword NID = 54 - NID_pkcs9_unstructuredAddress NID = 55 - NID_pkcs9_extCertAttributes NID = 56 - NID_netscape NID = 57 - NID_netscape_cert_extension NID = 58 - NID_netscape_data_type NID = 59 - NID_des_ede_cfb64 NID = 60 - NID_des_ede3_cfb64 NID = 61 - NID_des_ede_ofb64 NID = 62 - NID_des_ede3_ofb64 NID = 63 - NID_sha1 NID = 64 - NID_sha1WithRSAEncryption NID = 65 - NID_dsaWithSHA NID = 66 - NID_dsa_2 NID = 67 - NID_pbeWithSHA1AndRC2_CBC NID = 68 - NID_id_pbkdf2 NID = 69 - NID_dsaWithSHA1_2 NID = 70 - NID_netscape_cert_type NID = 71 - NID_netscape_base_url NID = 72 - NID_netscape_revocation_url NID = 73 - NID_netscape_ca_revocation_url NID = 74 - NID_netscape_renewal_url NID = 75 - NID_netscape_ca_policy_url NID = 76 - NID_netscape_ssl_server_name NID = 77 - NID_netscape_comment NID = 78 - NID_netscape_cert_sequence NID = 79 - NID_desx_cbc NID = 80 - NID_id_ce NID = 81 - NID_subject_key_identifier NID = 82 - NID_key_usage NID = 83 - NID_private_key_usage_period NID = 84 - NID_subject_alt_name NID = 85 - NID_issuer_alt_name NID = 86 - NID_basic_constraints NID = 87 - NID_crl_number NID = 88 - NID_certificate_policies NID = 89 - NID_authority_key_identifier NID = 90 - NID_bf_cbc NID = 91 - NID_bf_ecb NID = 92 - NID_bf_cfb64 NID = 93 - NID_bf_ofb64 NID = 94 - NID_mdc2 NID = 95 - NID_mdc2WithRSA NID = 96 - NID_rc4_40 NID = 97 - NID_rc2_40_cbc NID = 98 - NID_givenName NID = 99 - NID_surname NID = 100 - NID_initials NID = 101 - NID_uniqueIdentifier NID = 102 - NID_crl_distribution_points NID = 103 - NID_md5WithRSA NID = 104 - NID_serialNumber NID = 105 - NID_title NID = 106 - NID_description NID = 107 - NID_cast5_cbc NID = 108 - NID_cast5_ecb NID = 109 - NID_cast5_cfb64 NID = 110 - NID_cast5_ofb64 NID = 111 - NID_pbeWithMD5AndCast5_CBC NID = 112 - NID_dsaWithSHA1 NID = 113 - NID_md5_sha1 NID = 114 - NID_sha1WithRSA NID = 115 - NID_dsa NID = 116 - NID_ripemd160 NID = 117 - NID_ripemd160WithRSA NID = 119 - NID_rc5_cbc NID = 120 - NID_rc5_ecb NID = 121 - NID_rc5_cfb64 NID = 122 - NID_rc5_ofb64 NID = 123 - NID_rle_compression NID = 124 - NID_zlib_compression NID = 125 - NID_ext_key_usage NID = 126 - NID_id_pkix NID = 127 - NID_id_kp NID = 128 - NID_server_auth NID = 129 - NID_client_auth NID = 130 - NID_code_sign NID = 131 - NID_email_protect NID = 132 - NID_time_stamp NID = 133 - NID_ms_code_ind NID = 134 - NID_ms_code_com NID = 135 - NID_ms_ctl_sign NID = 136 - NID_ms_sgc NID = 137 - NID_ms_efs NID = 138 - NID_ns_sgc NID = 139 - NID_delta_crl NID = 140 - NID_crl_reason NID = 141 - NID_invalidity_date NID = 142 - NID_sxnet NID = 143 - NID_pbe_WithSHA1And128BitRC4 NID = 144 - NID_pbe_WithSHA1And40BitRC4 NID = 145 - NID_pbe_WithSHA1And3_Key_TripleDES_CBC NID = 146 - NID_pbe_WithSHA1And2_Key_TripleDES_CBC NID = 147 - NID_pbe_WithSHA1And128BitRC2_CBC NID = 148 - NID_pbe_WithSHA1And40BitRC2_CBC NID = 149 - NID_keyBag NID = 150 - NID_pkcs8ShroudedKeyBag NID = 151 - NID_certBag NID = 152 - NID_crlBag NID = 153 - NID_secretBag NID = 154 - NID_safeContentsBag NID = 155 - NID_friendlyName NID = 156 - NID_localKeyID NID = 157 - NID_x509Certificate NID = 158 - NID_sdsiCertificate NID = 159 - NID_x509Crl NID = 160 - NID_pbes2 NID = 161 - NID_pbmac1 NID = 162 - NID_hmacWithSHA1 NID = 163 - NID_id_qt_cps NID = 164 - NID_id_qt_unotice NID = 165 - NID_rc2_64_cbc NID = 166 - NID_SMIMECapabilities NID = 167 - NID_pbeWithMD2AndRC2_CBC NID = 168 - NID_pbeWithMD5AndRC2_CBC NID = 169 - NID_pbeWithSHA1AndDES_CBC NID = 170 - NID_ms_ext_req NID = 171 - NID_ext_req NID = 172 - NID_name NID = 173 - NID_dnQualifier NID = 174 - NID_id_pe NID = 175 - NID_id_ad NID = 176 - NID_info_access NID = 177 - NID_ad_OCSP NID = 178 - NID_ad_ca_issuers NID = 179 - NID_OCSP_sign NID = 180 - NID_X9_62_id_ecPublicKey NID = 408 - NID_hmac NID = 855 - NID_cmac NID = 894 - NID_dhpublicnumber NID = 920 - NID_tls1_prf NID = 1021 - NID_hkdf NID = 1036 - NID_X25519 NID = 1034 - NID_X448 NID = 1035 - NID_ED25519 NID = 1087 - NID_ED448 NID = 1088 - NID_sm2 NID = 1172 + NidUndef NID = 0 + NidRsadsi NID = 1 + NidPkcs NID = 2 + NidMd2 NID = 3 + NidMd5 NID = 4 + NidRc4 NID = 5 + NidRsaEncryption NID = 6 + NidMd2WithRSAEncryption NID = 7 + NidMd5WithRSAEncryption NID = 8 + NidPbeWithMD2AndDESCBC NID = 9 + NidPbeWithMD5AndDESCBC NID = 10 + NidX500 NID = 11 + NidX509 NID = 12 + NidCommonName NID = 13 + NidCountryName NID = 14 + NidLocalityName NID = 15 + NidStateOrProvinceName NID = 16 + NidOrganizationName NID = 17 + NidOrganizationalUnitName NID = 18 + NidRsa NID = 19 + NidPkcs7 NID = 20 + NidPkcs7Data NID = 21 + NidPkcs7Signed NID = 22 + NidPkcs7Enveloped NID = 23 + NidPkcs7SignedAndEnveloped NID = 24 + NidPkcs7Digest NID = 25 + NidPkcs7Encrypted NID = 26 + NidPkcs3 NID = 27 + NidDhKeyAgreement NID = 28 + NidDesEcb NID = 29 + NidDesCfb64 NID = 30 + NidDesCbc NID = 31 + NidDesEde NID = 32 + NidDesEde3 NID = 33 + NidIdeaCbc NID = 34 + NidIdeaCfb64 NID = 35 + NidIdeaEcb NID = 36 + NidRc2Cbc NID = 37 + NidRc2Ecb NID = 38 + NidRc2Cfb64 NID = 39 + NidRc2Ofb64 NID = 40 + NidSha NID = 41 + NidShaWithRSAEncryption NID = 42 + NidDesEdeCbc NID = 43 + NidDesEde3Cbc NID = 44 + NidDesOfb64 NID = 45 + NidIdeaOfb64 NID = 46 + NidPkcs9 NID = 47 + NidPkcs9EmailAddress NID = 48 + NidPkcs9UnstructuredName NID = 49 + NidPkcs9ContentType NID = 50 + NidPkcs9MessageDigest NID = 51 + NidPkcs9SigningTime NID = 52 + NidPkcs9Countersignature NID = 53 + NidPkcs9ChallengePassword NID = 54 + NidPkcs9UnstructuredAddress NID = 55 + NidPkcs9ExtCertAttributes NID = 56 + NidNetscape NID = 57 + NidNetscapeCertExtension NID = 58 + NidNetscapeDataType NID = 59 + NidDesEdeCfb64 NID = 60 + NidDesEde3Cfb64 NID = 61 + NidDesEdeOfb64 NID = 62 + NidDesEde3Ofb64 NID = 63 + NidSha1 NID = 64 + NidSha1WithRSAEncryption NID = 65 + NidDsaWithSHA NID = 66 + NidDsa2 NID = 67 + NidPbeWithSHA1AndRC2CBC NID = 68 + NidIDPbkdf2 NID = 69 + NidDsaWithSHA12 NID = 70 + NidNetscapeCertType NID = 71 + NidNetscapeBaseURL NID = 72 + NidNetscapeRevocationURL NID = 73 + NidNetscapeCaRevocationURL NID = 74 + NidNetscapeRenewalURL NID = 75 + NidNetscapeCaPolicyURL NID = 76 + NidNetscapeSslServerName NID = 77 + NidNetscapeComment NID = 78 + NidNetscapeCertSequence NID = 79 + NidDesxCbc NID = 80 + NidIDCe NID = 81 + NidSubjectKeyIdentifier NID = 82 + NidKeyUsage NID = 83 + NidPrivateKeyUsagePeriod NID = 84 + NidSubjectAltName NID = 85 + NidIssuerAltName NID = 86 + NidBasicConstraints NID = 87 + NidCrlNumber NID = 88 + NidCertificatePolicies NID = 89 + NidAuthorityKeyIdentifier NID = 90 + NidBfCbc NID = 91 + NidBfEcb NID = 92 + NidBfCfb64 NID = 93 + NidBfOfb64 NID = 94 + NidMdc2 NID = 95 + NidMdc2WithRSA NID = 96 + NidRc440 NID = 97 + NidRc240Cbc NID = 98 + NidGivenName NID = 99 + NidSurname NID = 100 + NidInitials NID = 101 + NidUniqueIdentifier NID = 102 + NidCrlDistributionPoints NID = 103 + NidMd5WithRSA NID = 104 + NidSerialNumber NID = 105 + NidTitle NID = 106 + NidDescription NID = 107 + NidCast5Cbc NID = 108 + NidCast5Ecb NID = 109 + NidCast5Cfb64 NID = 110 + NidCast5Ofb64 NID = 111 + NidPbeWithMD5AndCast5CBC NID = 112 + NidDsaWithSHA1 NID = 113 + NidMd5Sha1 NID = 114 + NidSha1WithRSA NID = 115 + NidDsa NID = 116 + NidRipemd160 NID = 117 + NidRipemd160WithRSA NID = 119 + NidRc5Cbc NID = 120 + NidRc5Ecb NID = 121 + NidRc5Cfb64 NID = 122 + NidRc5Ofb64 NID = 123 + NidRleCompression NID = 124 + NidZlibCompression NID = 125 + NidExtKeyUsage NID = 126 + NidIDPkix NID = 127 + NidIDKp NID = 128 + NidServerAuth NID = 129 + NidClientAuth NID = 130 + NidCodeSign NID = 131 + NidEmailProtect NID = 132 + NidTimeStamp NID = 133 + NidMsCodeInd NID = 134 + NidMsCodeCom NID = 135 + NidMsCtlSign NID = 136 + NidMsSgc NID = 137 + NidMsEfs NID = 138 + NidNsSgc NID = 139 + NidDeltaCrl NID = 140 + NidCrlReason NID = 141 + NidInvalidityDate NID = 142 + NidSxnet NID = 143 + NidPbeWithSHA1And128BitRC4 NID = 144 + NidPbeWithSHA1And40BitRC4 NID = 145 + NidPbeWithSHA1And3KeyTripleDESCBC NID = 146 + NidPbeWithSHA1And2KeyTripleDESCBC NID = 147 + NidPbeWithSHA1And128BitRC2CBC NID = 148 + NidPbeWithSHA1And40BitRC2CBC NID = 149 + NidKeyBag NID = 150 + NidPkcs8ShroudedKeyBag NID = 151 + NidCertBag NID = 152 + NidCrlBag NID = 153 + NidSecretBag NID = 154 + NidSafeContentsBag NID = 155 + NidFriendlyName NID = 156 + NidLocalKeyID NID = 157 + NidX509Certificate NID = 158 + NidSdsiCertificate NID = 159 + NidX509Crl NID = 160 + NidPbes2 NID = 161 + NidPbmac1 NID = 162 + NidHmacWithSHA1 NID = 163 + NidIDQtCps NID = 164 + NidIDQtUnotice NID = 165 + NidRc264Cbc NID = 166 + NidSMIMECapabilities NID = 167 + NidPbeWithMD2AndRC2CBC NID = 168 + NidPbeWithMD5AndRC2CBC NID = 169 + NidPbeWithSHA1AndDESCBC NID = 170 + NidMsExtReq NID = 171 + NidExtReq NID = 172 + NidName NID = 173 + NidDnQualifier NID = 174 + NidIDPe NID = 175 + NidIDAd NID = 176 + NidInfoAccess NID = 177 + NidAdOCSP NID = 178 + NidAdCaIssuers NID = 179 + NidOCSPSign NID = 180 + NidX962IdEcPublicKey NID = 408 + NidHmac NID = 855 + NidCmac NID = 894 + NidDhpublicnumber NID = 920 + NidTLS1Prf NID = 1021 + NidHkdf NID = 1036 + NidX25519 NID = 1034 + NidX448 NID = 1035 + NidEd25519 NID = 1087 + NidEd448 NID = 1088 + NidSM2 NID = 1172 ) diff --git a/crypto/sha1/sha1.go b/crypto/sha1/sha1.go index c71e56c..2820b84 100644 --- a/crypto/sha1/sha1.go +++ b/crypto/sha1/sha1.go @@ -15,81 +15,85 @@ package sha1 // #include "../shim.h" -// #cgo linux LDFLAGS: -lcrypto -// #cgo darwin LDFLAGS: -lcrypto -// #cgo windows CFLAGS: -DWIN32_LEAN_AND_MEAN -// #cgo windows pkg-config: libcrypto import "C" import ( - "errors" + "fmt" "runtime" "unsafe" "github.com/tongsuo-project/tongsuo-go-sdk/crypto" ) -type SHA1Hash struct { +const MDSize = 20 + +type SHA1 struct { ctx *C.EVP_MD_CTX engine *crypto.Engine } -func New() (*SHA1Hash, error) { return NewWithEngine(nil) } +func New() (*SHA1, error) { return NewWithEngine(nil) } -func NewWithEngine(e *crypto.Engine) (*SHA1Hash, error) { - hash := &SHA1Hash{engine: e} +func NewWithEngine(e *crypto.Engine) (*SHA1, error) { + hash := &SHA1{ctx: nil, engine: e} hash.ctx = C.X_EVP_MD_CTX_new() if hash.ctx == nil { - return nil, errors.New("openssl: sha1: unable to allocate ctx") + return nil, fmt.Errorf("failed to create md ctx: %w", crypto.ErrMallocFailure) } - runtime.SetFinalizer(hash, func(hash *SHA1Hash) { hash.Close() }) + runtime.SetFinalizer(hash, func(hash *SHA1) { hash.Close() }) if err := hash.Reset(); err != nil { return nil, err } + return hash, nil } -func (s *SHA1Hash) Close() { +func (s *SHA1) Close() { if s.ctx != nil { C.X_EVP_MD_CTX_free(s.ctx) s.ctx = nil } } -func (s *SHA1Hash) Reset() error { - if 1 != C.X_EVP_DigestInit_ex(s.ctx, C.X_EVP_sha1(), (*C.ENGINE)(s.engine.Engine())) { - return errors.New("openssl: sha1: cannot init digest ctx") +func (s *SHA1) Reset() error { + if C.X_EVP_DigestInit_ex(s.ctx, C.X_EVP_sha1(), (*C.ENGINE)(s.engine.Engine())) != 1 { + return fmt.Errorf("failed to init digest ctx %w", crypto.PopError()) } + return nil } -func (s *SHA1Hash) Write(p []byte) (n int, err error) { - if len(p) == 0 { +func (s *SHA1) Write(data []byte) (int, error) { + if len(data) == 0 { return 0, nil } - if 1 != C.X_EVP_DigestUpdate(s.ctx, unsafe.Pointer(&p[0]), - C.size_t(len(p))) { - return 0, errors.New("openssl: sha1: cannot update digest") + if C.X_EVP_DigestUpdate(s.ctx, unsafe.Pointer(&data[0]), C.size_t(len(data))) != 1 { + return 0, fmt.Errorf("failed to update digest: %w", crypto.PopError()) } - return len(p), nil + + return len(data), nil } -func (s *SHA1Hash) Sum() (result [20]byte, err error) { - if 1 != C.X_EVP_DigestFinal_ex(s.ctx, - (*C.uchar)(unsafe.Pointer(&result[0])), nil) { - return result, errors.New("openssl: sha1: cannot finalize ctx") +func (s *SHA1) Sum() ([MDSize]byte, error) { + var result [MDSize]byte + + if C.X_EVP_DigestFinal_ex(s.ctx, (*C.uchar)(unsafe.Pointer(&result[0])), nil) != 1 { + return result, fmt.Errorf("failed to finalize digest: %w", crypto.PopError()) } + return result, s.Reset() } -func Sum(data []byte) (result [20]byte, err error) { +func Sum(data []byte) ([MDSize]byte, error) { hash, err := New() if err != nil { - return result, err + return [MDSize]byte{}, err } + defer hash.Close() + if _, err := hash.Write(data); err != nil { - return result, err + return [MDSize]byte{}, err } return hash.Sum() } diff --git a/crypto/sha1/sha1_test.go b/crypto/sha1/sha1_test.go index 0101239..2b458dc 100644 --- a/crypto/sha1/sha1_test.go +++ b/crypto/sha1/sha1_test.go @@ -12,16 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -package sha1 +package sha1_test import ( "crypto/rand" "crypto/sha1" "io" "testing" + + tsSHA1 "github.com/tongsuo-project/tongsuo-go-sdk/crypto/sha1" ) func TestSHA1(t *testing.T) { + t.Parallel() + for i := 0; i < 100; i++ { buf := make([]byte, 10*1024-i) if _, err := io.ReadFull(rand.Reader, buf); err != nil { @@ -29,7 +33,8 @@ func TestSHA1(t *testing.T) { } expected := sha1.Sum(buf) - got, err := Sum(buf) + + got, err := tsSHA1.Sum(buf) if err != nil { t.Fatal(err) } @@ -41,18 +46,24 @@ func TestSHA1(t *testing.T) { } func TestSHA1Writer(t *testing.T) { - ohash, err := New() + t.Parallel() + + ohash, err := tsSHA1.New() if err != nil { t.Fatal(err) } + hash := sha1.New() for i := 0; i < 100; i++ { if err := ohash.Reset(); err != nil { t.Fatal(err) } + hash.Reset() + buf := make([]byte, 10*1024-i) + if _, err := io.ReadFull(rand.Reader, buf); err != nil { t.Fatal(err) } @@ -60,6 +71,7 @@ func TestSHA1Writer(t *testing.T) { if _, err := ohash.Write(buf); err != nil { t.Fatal(err) } + if _, err := hash.Write(buf); err != nil { t.Fatal(err) } @@ -67,6 +79,7 @@ func TestSHA1Writer(t *testing.T) { var got, exp [20]byte hash.Sum(exp[:0]) + got, err := ohash.Sum() if err != nil { t.Fatal(err) @@ -81,19 +94,23 @@ func TestSHA1Writer(t *testing.T) { type shafunc func([]byte) func benchmarkSHA1(b *testing.B, length int64, fn shafunc) { + b.Helper() + buf := make([]byte, length) if _, err := io.ReadFull(rand.Reader, buf); err != nil { b.Fatal(err) } + b.SetBytes(length) b.ResetTimer() + for i := 0; i < b.N; i++ { fn(buf) } } func BenchmarkSHA1Large_openssl(b *testing.B) { - benchmarkSHA1(b, 1024*1024, func(buf []byte) { Sum(buf) }) + benchmarkSHA1(b, 1024*1024, func(buf []byte) { _, _ = tsSHA1.Sum(buf) }) } func BenchmarkSHA1Large_stdlib(b *testing.B) { @@ -101,7 +118,7 @@ func BenchmarkSHA1Large_stdlib(b *testing.B) { } func BenchmarkSHA1Small_openssl(b *testing.B) { - benchmarkSHA1(b, 1, func(buf []byte) { Sum(buf) }) + benchmarkSHA1(b, 1, func(buf []byte) { _, _ = tsSHA1.Sum(buf) }) } func BenchmarkSHA1Small_stdlib(b *testing.B) { diff --git a/crypto/sha256/sha256.go b/crypto/sha256/sha256.go index 37fe62f..c11447b 100644 --- a/crypto/sha256/sha256.go +++ b/crypto/sha256/sha256.go @@ -15,14 +15,10 @@ package sha256 // #include "../shim.h" -// #cgo linux LDFLAGS: -lcrypto -// #cgo darwin LDFLAGS: -lcrypto -// #cgo windows CFLAGS: -DWIN32_LEAN_AND_MEAN -// #cgo windows pkg-config: libcrypto import "C" import ( - "errors" + "fmt" "runtime" "unsafe" @@ -37,15 +33,16 @@ type SHA256 struct { func New() (*SHA256, error) { return NewWithEngine(nil) } func NewWithEngine(e *crypto.Engine) (*SHA256, error) { - hash := &SHA256{engine: e} + hash := &SHA256{ctx: nil, engine: e} hash.ctx = C.X_EVP_MD_CTX_new() if hash.ctx == nil { - return nil, errors.New("openssl: sha256: unable to allocate ctx") + return nil, fmt.Errorf("failed to create md ctx %w", crypto.ErrMallocFailure) } runtime.SetFinalizer(hash, func(hash *SHA256) { hash.Close() }) if err := hash.Reset(); err != nil { return nil, err } + return hash, nil } @@ -57,39 +54,45 @@ func (s *SHA256) Close() { } func (s *SHA256) Reset() error { - if 1 != C.X_EVP_DigestInit_ex(s.ctx, C.X_EVP_sha256(), (*C.ENGINE)(s.engine.Engine())) { - return errors.New("openssl: sha256: cannot init digest ctx") + if C.X_EVP_DigestInit_ex(s.ctx, C.X_EVP_sha256(), (*C.ENGINE)(s.engine.Engine())) != 1 { + return fmt.Errorf("failed to init digest ctx: %w", crypto.PopError()) } + return nil } -func (s *SHA256) Write(p []byte) (n int, err error) { - if len(p) == 0 { +func (s *SHA256) Write(data []byte) (int, error) { + if len(data) == 0 { return 0, nil } - if 1 != C.X_EVP_DigestUpdate(s.ctx, unsafe.Pointer(&p[0]), - C.size_t(len(p))) { - return 0, errors.New("openssl: sha256: cannot update digest") + if C.X_EVP_DigestUpdate(s.ctx, unsafe.Pointer(&data[0]), C.size_t(len(data))) != 1 { + return 0, fmt.Errorf("failed to update digest: %w", crypto.PopError()) } - return len(p), nil + + return len(data), nil } -func (s *SHA256) Sum() (result [32]byte, err error) { - if 1 != C.X_EVP_DigestFinal_ex(s.ctx, - (*C.uchar)(unsafe.Pointer(&result[0])), nil) { - return result, errors.New("openssl: sha256: cannot finalize ctx") +func (s *SHA256) Sum() ([32]byte, error) { + var result [32]byte + + if C.X_EVP_DigestFinal_ex(s.ctx, (*C.uchar)(unsafe.Pointer(&result[0])), nil) != 1 { + return result, fmt.Errorf("failed to finalize digest: %w", crypto.PopError()) } + return result, s.Reset() } -func Sum(data []byte) (result [32]byte, err error) { +func Sum(data []byte) ([32]byte, error) { hash, err := New() if err != nil { - return result, err + return [32]byte{}, err } + defer hash.Close() + if _, err := hash.Write(data); err != nil { - return result, err + return [32]byte{}, err } + return hash.Sum() } diff --git a/crypto/sha256/sha256_test.go b/crypto/sha256/sha256_test.go index e5f5d9b..4f7859f 100644 --- a/crypto/sha256/sha256_test.go +++ b/crypto/sha256/sha256_test.go @@ -12,16 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -package sha256 +package sha256_test import ( "crypto/rand" "crypto/sha256" "io" "testing" + + tsSHA256 "github.com/tongsuo-project/tongsuo-go-sdk/crypto/sha256" ) func Test(t *testing.T) { + t.Parallel() + for i := 0; i < 100; i++ { buf := make([]byte, 10*1024-i) if _, err := io.ReadFull(rand.Reader, buf); err != nil { @@ -29,7 +33,8 @@ func Test(t *testing.T) { } expected := sha256.Sum256(buf) - got, err := Sum(buf) + + got, err := tsSHA256.Sum(buf) if err != nil { t.Fatal(err) } @@ -41,17 +46,22 @@ func Test(t *testing.T) { } func TestWriter(t *testing.T) { - ohash, err := New() + t.Parallel() + + ohash, err := tsSHA256.New() if err != nil { t.Fatal(err) } + hash := sha256.New() for i := 0; i < 100; i++ { if err := ohash.Reset(); err != nil { t.Fatal(err) } + hash.Reset() + buf := make([]byte, 10*1024-i) if _, err := io.ReadFull(rand.Reader, buf); err != nil { t.Fatal(err) @@ -60,6 +70,7 @@ func TestWriter(t *testing.T) { if _, err := ohash.Write(buf); err != nil { t.Fatal(err) } + if _, err := hash.Write(buf); err != nil { t.Fatal(err) } @@ -67,6 +78,7 @@ func TestWriter(t *testing.T) { var got, exp [32]byte hash.Sum(exp[:0]) + got, err := ohash.Sum() if err != nil { t.Fatal(err) @@ -79,19 +91,23 @@ func TestWriter(t *testing.T) { } func benchmark(b *testing.B, length int64, fn func([]byte)) { + b.Helper() + buf := make([]byte, length) if _, err := io.ReadFull(rand.Reader, buf); err != nil { b.Fatal(err) } + b.SetBytes(length) b.ResetTimer() + for i := 0; i < b.N; i++ { fn(buf) } } func BenchmarkLarge_openssl(b *testing.B) { - benchmark(b, 1024*1024, func(buf []byte) { Sum(buf) }) + benchmark(b, 1024*1024, func(buf []byte) { _, _ = tsSHA256.Sum(buf) }) } func BenchmarkLarge_stdlib(b *testing.B) { @@ -99,7 +115,7 @@ func BenchmarkLarge_stdlib(b *testing.B) { } func BenchmarkSmall_openssl(b *testing.B) { - benchmark(b, 1, func(buf []byte) { Sum(buf) }) + benchmark(b, 1, func(buf []byte) { _, _ = tsSHA256.Sum(buf) }) } func BenchmarkSmall_stdlib(b *testing.B) { diff --git a/crypto/sm2/sm2.go b/crypto/sm2/sm2.go index 10f09fa..7805f17 100644 --- a/crypto/sm2/sm2.go +++ b/crypto/sm2/sm2.go @@ -8,14 +8,10 @@ package sm2 // #include "../shim.h" -// #cgo linux LDFLAGS: -lcrypto -// #cgo darwin LDFLAGS: -lcrypto -// #cgo windows CFLAGS: -Wall -DWIN32_LEAN_AND_MEAN -// #cgo windows pkg-config: libcrypto import "C" import ( - "errors" + "fmt" "math/big" "unsafe" @@ -24,27 +20,37 @@ import ( // VerifyASN1 verifies ASN.1 encoded signature. Returns nil on success. func VerifyASN1(pub crypto.PublicKey, data, sig []byte) error { - if pub.KeyType() != crypto.NID_sm2 { - return errors.New("SM2: key type is not sm2") + if pub.KeyType() != crypto.NidSM2 { + return fmt.Errorf("key type is not sm2: %w", crypto.ErrWrongKeyType) } - return pub.VerifyPKCS1v15(crypto.SM3_Method, data, sig) + err := pub.VerifyPKCS1v15(crypto.SM3Method(), data, sig) + if err != nil { + return fmt.Errorf("failed to verify: %w", err) + } + + return nil } // SignASN1 signs the data with priv and returns ASN.1 encoded signature. func SignASN1(priv crypto.PrivateKey, data []byte) ([]byte, error) { - if priv.KeyType() != crypto.NID_sm2 { - return nil, errors.New("SM2: key type is not sm2") + if priv.KeyType() != crypto.NidSM2 { + return nil, fmt.Errorf("key type is not sm2: %w", crypto.ErrWrongKeyType) + } + + ret, err := priv.SignPKCS1v15(crypto.SM3Method(), data) + if err != nil { + return nil, fmt.Errorf("failed to sign: %w", err) } - return priv.SignPKCS1v15(crypto.SM3_Method, data) + return ret, nil } // Verify verifies the signature in r, s of data using the public key, pub. // Returns nil on success. func Verify(pub crypto.PublicKey, data []byte, r, s *big.Int) error { - if pub.KeyType() != crypto.NID_sm2 { - return errors.New("SM2: key type is not sm2") + if pub.KeyType() != crypto.NidSM2 { + return fmt.Errorf("key type is not sm2 %w", crypto.ErrWrongKeyType) } sm2Sig := C.ECDSA_SIG_new() @@ -55,29 +61,29 @@ func Verify(pub crypto.PublicKey, data []byte, r, s *big.Int) error { ret := C.ECDSA_SIG_set0(sm2Sig, rBig, sBig) if ret != 1 { - return errors.New("SM2: set r,s failed") + return fmt.Errorf("failed to set r/s: %w", crypto.ErrNilParameter) } - len := C.i2d_ECDSA_SIG(sm2Sig, nil) + len1 := C.i2d_ECDSA_SIG(sm2Sig, nil) - buf := (*C.uchar)(C.malloc(C.size_t(len))) + buf := (*C.uchar)(C.malloc(C.size_t(len1))) defer C.free(unsafe.Pointer(buf)) tmp := buf len2 := C.i2d_ECDSA_SIG(sm2Sig, &tmp) - return VerifyASN1(pub, data, C.GoBytes(unsafe.Pointer(buf), C.int(len2))) + return VerifyASN1(pub, data, C.GoBytes(unsafe.Pointer(buf), len2)) } // Sign signs the data with the private key, priv. -func Sign(priv crypto.PrivateKey, data []byte) (r, s *big.Int, err error) { - if priv.KeyType() != crypto.NID_sm2 { - return nil, nil, errors.New("SM2: key type is not sm2") +func Sign(priv crypto.PrivateKey, data []byte) (*big.Int, *big.Int, error) { + if priv.KeyType() != crypto.NidSM2 { + return nil, nil, fmt.Errorf("key type is not sm2: %w", crypto.ErrWrongKeyType) } - sig, err := priv.SignPKCS1v15(crypto.SM3_Method, data) + sig, err := priv.SignPKCS1v15(crypto.SM3Method(), data) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to sign data: %w", err) } buf := (*C.uchar)(C.malloc(C.size_t(len(sig)))) @@ -86,7 +92,7 @@ func Sign(priv crypto.PrivateKey, data []byte) (r, s *big.Int, err error) { sm2Sig := C.d2i_ECDSA_SIG(nil, &buf, C.long(len(sig))) if sm2Sig == nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to decode signature: %w", err) } defer C.ECDSA_SIG_free(sm2Sig) @@ -99,17 +105,17 @@ func Sign(priv crypto.PrivateKey, data []byte) (r, s *big.Int, err error) { rLen := C.BN_bn2bin(rBig, (*C.uchar)(unsafe.Pointer(&rBytes[0]))) sLen := C.BN_bn2bin(sBig, (*C.uchar)(unsafe.Pointer(&sBytes[0]))) - r = new(big.Int).SetBytes(rBytes[:rLen]) - s = new(big.Int).SetBytes(sBytes[:sLen]) + r := new(big.Int).SetBytes(rBytes[:rLen]) + s := new(big.Int).SetBytes(sBytes[:sLen]) return r, s, nil } // GenerateKey generates a new SM2 key pair. func GenerateKey() (crypto.PrivateKey, error) { - priv, err := crypto.GenerateECKey(crypto.Sm2Curve) + priv, err := crypto.GenerateECKey(crypto.SM2Curve) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create key: %w", err) } return priv, nil @@ -117,18 +123,28 @@ func GenerateKey() (crypto.PrivateKey, error) { // Encrypt encrypts the data with the public key, publ. func Encrypt(pub crypto.PublicKey, data []byte) ([]byte, error) { - if pub.KeyType() != crypto.NID_sm2 { - return nil, errors.New("SM2: key type is not sm2") + if pub.KeyType() != crypto.NidSM2 { + return nil, fmt.Errorf("key type is not sm2: %w", crypto.ErrWrongKeyType) } - return pub.Encrypt(data) + ret, err := pub.Encrypt(data) + if err != nil { + return nil, fmt.Errorf("failed to encrypt: %w", err) + } + + return ret, nil } // Decrypt decrypts the ciphertext with the private key, priv. func Decrypt(priv crypto.PrivateKey, data []byte) ([]byte, error) { - if priv.KeyType() != crypto.NID_sm2 { - return nil, errors.New("SM2: key type is not sm2") + if priv.KeyType() != crypto.NidSM2 { + return nil, fmt.Errorf("key type is not sm2: %w", crypto.ErrWrongKeyType) + } + + ret, err := priv.Decrypt(data) + if err != nil { + return nil, fmt.Errorf("failed to decrypt: %w", err) } - return priv.Decrypt(data) + return ret, nil } diff --git a/crypto/sm2/sm2_test.go b/crypto/sm2/sm2_test.go index 7b0cef7..189a769 100644 --- a/crypto/sm2/sm2_test.go +++ b/crypto/sm2/sm2_test.go @@ -5,58 +5,71 @@ // in the file LICENSE in the source distribution or at // https://github.com/Tongsuo-Project/tongsuo-go-sdk/blob/main/LICENSE -package sm2 +package sm2_test import ( "encoding/hex" "math/big" + "strings" "testing" "github.com/tongsuo-project/tongsuo-go-sdk/crypto" + "github.com/tongsuo-project/tongsuo-go-sdk/crypto/sm2" ) -var sm2_key1 = []byte(`-----BEGIN PRIVATE KEY----- +const sm2Key1 = `-----BEGIN PRIVATE KEY----- MIGHAgEAMBMGByqGSM49AgEGCCqBHM9VAYItBG0wawIBAQQg0JFWczAXva2An9m7 2MaT9gIwWTFptvlKrxyO4TjMmbWhRANCAAQ5OirZ4n5DrKqrhaGdO4VZHhRAYVcX Wt3Te/d/8Mr57Tf886i09VwDhSMmH8pmNq/mp6+ioUgqYG9cs6GLLioe -----END PRIVATE KEY----- -`) +` -var sm2_pubkey1 = []byte(`-----BEGIN PUBLIC KEY----- +const sm2Pubkey1 = `-----BEGIN PUBLIC KEY----- MFkwEwYHKoZIzj0CAQYIKoEcz1UBgi0DQgAEOToq2eJ+Q6yqq4WhnTuFWR4UQGFX F1rd03v3f/DK+e03/POotPVcA4UjJh/KZjav5qevoqFIKmBvXLOhiy4qHg== -----END PUBLIC KEY----- -`) +` func TestSM2PublicKeyVerifyASN1(t *testing.T) { - pub, err := crypto.LoadPublicKeyFromPEM(sm2_pubkey1) + t.Parallel() + + pub, err := crypto.LoadPublicKeyFromPEM([]byte(sm2Pubkey1)) if err != nil { t.Fatal(err) } data := []byte("hello world") - sig, _ := hex.DecodeString("3046022100ba37b776135afbf5bf36b21f4a65889bcd0037092be47f6429f877790b8cb9c402210097b59fd56d41317d490dd300e7e69d7909a0885414ac3b2c5a24bdfc1588cb55") + hexSig := `3046022100ba37b776135afbf5bf36b21f4a65889bcd0037092be47f6429f877790b8cb9c402210097b59fd56d41317d490dd300e +7e69d7909a0885414ac3b2c5a24bdfc1588cb55` + sig, _ := hex.DecodeString(strings.ReplaceAll(hexSig, "\n", "")) - if VerifyASN1(pub, data, sig) != nil { + if sm2.VerifyASN1(pub, data, sig) != nil { t.Fatal() } } func TestSM2PrivateKey2PublicVerifyASN1(t *testing.T) { - priv, err := crypto.LoadPrivateKeyFromPEM(sm2_key1) + t.Parallel() + + priv, err := crypto.LoadPrivateKeyFromPEM([]byte(sm2Key1)) if err != nil { t.Fatal(err) } data := []byte("hello world") - sig, _ := hex.DecodeString("3046022100ba37b776135afbf5bf36b21f4a65889bcd0037092be47f6429f877790b8cb9c402210097b59fd56d41317d490dd300e7e69d7909a0885414ac3b2c5a24bdfc1588cb55") - if VerifyASN1(priv.Public(), data, sig) != nil { + hexSig := `3046022100ba37b776135afbf5bf36b21f4a65889bcd0037092be47f6429f877790b8cb9c402210097b59fd56d41317d490dd300e +7e69d7909a0885414ac3b2c5a24bdfc1588cb55` + sig, _ := hex.DecodeString(strings.ReplaceAll(hexSig, "\n", "")) + + if sm2.VerifyASN1(priv.Public(), data, sig) != nil { t.Fatal() } } func TestSM2PrivateKey2PublicVerify(t *testing.T) { - priv, err := crypto.LoadPrivateKeyFromPEM(sm2_key1) + t.Parallel() + + priv, err := crypto.LoadPrivateKeyFromPEM([]byte(sm2Key1)) if err != nil { t.Fatal(err) } @@ -67,84 +80,97 @@ func TestSM2PrivateKey2PublicVerify(t *testing.T) { r := new(big.Int).SetBytes(rBytes) s := new(big.Int).SetBytes(sBytes) - if Verify(priv.Public(), data, r, s) != nil { + if sm2.Verify(priv.Public(), data, r, s) != nil { t.Fatal() } } func TestSM2VerifySignASN1(t *testing.T) { - priv, err := crypto.LoadPrivateKeyFromPEM(sm2_key1) + t.Parallel() + + priv, err := crypto.LoadPrivateKeyFromPEM([]byte(sm2Key1)) if err != nil { t.Fatal(err) } data := []byte("hello world") - sig, err := SignASN1(priv, data) + + sig, err := sm2.SignASN1(priv, data) if err != nil { t.Fatal(err) } - if VerifyASN1(priv.Public(), data, sig) != nil { + if sm2.VerifyASN1(priv.Public(), data, sig) != nil { t.Fatal() } } func TestNewSM2VerifySignASN1(t *testing.T) { - priv, err := GenerateKey() + t.Parallel() + + priv, err := sm2.GenerateKey() if err != nil { t.Fatal(err) } pub := priv.Public() - data := []byte("hello world") - sig, err := SignASN1(priv, data) + + sig, err := sm2.SignASN1(priv, data) if err != nil { t.Fatal(err) } - if VerifyASN1(pub, data, sig) != nil { + if sm2.VerifyASN1(pub, data, sig) != nil { t.Fatal() } } func TestSM2VerifySign(t *testing.T) { - priv, err := crypto.LoadPrivateKeyFromPEM(sm2_key1) + t.Parallel() + + priv, err := crypto.LoadPrivateKeyFromPEM([]byte(sm2Key1)) if err != nil { t.Fatal(err) } data := []byte("hello world") - r, s, err := Sign(priv, data) + + r, s, err := sm2.Sign(priv, data) if err != nil { t.Fatal(err) } - if Verify(priv.Public(), data, r, s) != nil { + if sm2.Verify(priv.Public(), data, r, s) != nil { t.Fatal() } } func TestNewSM2VerifySign(t *testing.T) { - priv, err := GenerateKey() + t.Parallel() + + priv, err := sm2.GenerateKey() if err != nil { t.Fatal(err) } pub := priv.Public() data := []byte("hello world") - r, s, err := Sign(priv, data) + + r, s, err := sm2.Sign(priv, data) if err != nil { t.Fatal(err) } - if Verify(pub, data, r, s) != nil { + if sm2.Verify(pub, data, r, s) != nil { t.Fatal() } } func TestSM2DecryptEncrypt(t *testing.T) { - priv, err := crypto.LoadPrivateKeyFromPEM(sm2_key1) + t.Parallel() + + priv, err := crypto.LoadPrivateKeyFromPEM([]byte(sm2Key1)) if err != nil { t.Fatal(err) } @@ -152,12 +178,12 @@ func TestSM2DecryptEncrypt(t *testing.T) { pub := priv.Public() data := []byte("hello world") - ciphertext, err := Encrypt(pub, data) + ciphertext, err := sm2.Encrypt(pub, data) if err != nil { t.Fatal(err) } - plaintext, err := Decrypt(priv, ciphertext) + plaintext, err := sm2.Decrypt(priv, ciphertext) if err != nil { t.Fatal(err) } @@ -168,7 +194,9 @@ func TestSM2DecryptEncrypt(t *testing.T) { } func TestNewSM2DecryptEncrypt(t *testing.T) { - priv, err := GenerateKey() + t.Parallel() + + priv, err := sm2.GenerateKey() if err != nil { t.Fatal(err) } @@ -176,12 +204,12 @@ func TestNewSM2DecryptEncrypt(t *testing.T) { pub := priv.Public() data := []byte("hello world") - ciphertext, err := Encrypt(pub, data) + ciphertext, err := sm2.Encrypt(pub, data) if err != nil { t.Fatal(err) } - plaintext, err := Decrypt(priv, ciphertext) + plaintext, err := sm2.Decrypt(priv, ciphertext) if err != nil { t.Fatal(err) } diff --git a/crypto/sm3/sm3.go b/crypto/sm3/sm3.go index e750691..902a055 100644 --- a/crypto/sm3/sm3.go +++ b/crypto/sm3/sm3.go @@ -8,14 +8,10 @@ package sm3 // #include "../shim.h" -// #cgo linux LDFLAGS: -lcrypto -// #cgo darwin LDFLAGS: -lcrypto -// #cgo windows CFLAGS: -DWIN32_LEAN_AND_MEAN -// #cgo windows pkg-config: libcrypto import "C" import ( - "errors" + "fmt" "hash" "runtime" "unsafe" @@ -24,8 +20,8 @@ import ( ) const ( - SM3_DIGEST_LENGTH = 32 - SM3_CBLOCK = 64 + MDSize = 32 + sm3Cblock = 64 ) var _ hash.Hash = new(SM3) @@ -38,30 +34,32 @@ type SM3 struct { func New() (*SM3, error) { return NewWithEngine(nil) } func NewWithEngine(e *crypto.Engine) (*SM3, error) { - h, err := newWithEngine(e) + hash, err := newWithEngine(e) if err != nil { return nil, err } - h.Reset() - return h, nil + hash.Reset() + + return hash, nil } func newWithEngine(e *crypto.Engine) (*SM3, error) { - hash := &SM3{engine: e} + hash := &SM3{ctx: nil, engine: e} hash.ctx = C.X_EVP_MD_CTX_new() if hash.ctx == nil { - return nil, errors.New("openssl: sm3: unable to allocate ctx") + return nil, fmt.Errorf("failed to create md ctx: %w", crypto.ErrMallocFailure) } runtime.SetFinalizer(hash, func(hash *SM3) { hash.Close() }) + return hash, nil } func (s *SM3) BlockSize() int { - return SM3_CBLOCK + return sm3Cblock } func (s *SM3) Size() int { - return SM3_DIGEST_LENGTH + return MDSize } func (s *SM3) Close() { @@ -75,14 +73,14 @@ func (s *SM3) Reset() { C.X_EVP_DigestInit_ex(s.ctx, C.EVP_sm3(), (*C.ENGINE)(s.engine.Engine())) } -func (s *SM3) Write(p []byte) (n int, err error) { - if len(p) == 0 { +func (s *SM3) Write(data []byte) (int, error) { + if len(data) == 0 { return 0, nil } - if 1 != C.X_EVP_DigestUpdate(s.ctx, unsafe.Pointer(&p[0]), C.size_t(len(p))) { - return 0, errors.New("openssl: sm3: cannot update digest") + if C.X_EVP_DigestUpdate(s.ctx, unsafe.Pointer(&data[0]), C.size_t(len(data))) != 1 { + return 0, fmt.Errorf("failed to update digest: %w", crypto.PopError()) } - return len(p), nil + return len(data), nil } func (s *SM3) Sum(in []byte) []byte { @@ -99,19 +97,19 @@ func (s *SM3) Sum(in []byte) []byte { return append(in, result[:]...) } -func (s *SM3) checkSum() (result [SM3_DIGEST_LENGTH]byte) { +func (s *SM3) checkSum() [MDSize]byte { + var result [MDSize]byte + C.X_EVP_DigestFinal_ex(s.ctx, (*C.uchar)(unsafe.Pointer(&result[0])), nil) + return result } -func Sum(data []byte) (result [SM3_DIGEST_LENGTH]byte) { - C.X_EVP_Digest( - unsafe.Pointer(&data[0]), - C.size_t(len(data)), - (*C.uchar)(unsafe.Pointer(&result[0])), - nil, - C.EVP_sm3(), - nil, - ) - return +func Sum(data []byte) [MDSize]byte { + var result [MDSize]byte + + C.X_EVP_Digest(unsafe.Pointer(&data[0]), C.size_t(len(data)), (*C.uchar)(unsafe.Pointer(&result[0])), nil, + C.EVP_sm3(), nil) + + return result } diff --git a/crypto/sm3/sm3_test.go b/crypto/sm3/sm3_test.go index 575029a..4f3e75b 100644 --- a/crypto/sm3/sm3_test.go +++ b/crypto/sm3/sm3_test.go @@ -5,34 +5,82 @@ // in the file LICENSE in the source distribution or at // https://github.com/Tongsuo-Project/tongsuo-go-sdk/blob/main/LICENSE -package sm3 +package sm3_test import ( "bytes" "crypto/rand" "encoding/hex" "io" + "strings" "testing" + + "github.com/tongsuo-project/tongsuo-go-sdk/crypto/sm3" ) func TestSM3(t *testing.T) { - var testData = []struct { + t.Parallel() + + testData := []struct { in string out string }{ - {"0090414C494345313233405941484F4F2E434F4D787968B4FA32C3FD2417842E73BBFEFF2F3C848B6831D7E0EC65228B3937E49863E4C6D3B23B0C849CF84241484BFE48F61D59A5B16BA06E6E12D1DA27C5249A421DEBD61B62EAB6746434EBC3CC315E32220B3BADD50BDC4C4E6C147FEDD43D0680512BCBB42C07D47349D2153B70C4E5D7FDFCBFA36EA1A85841B9E46E09A20AE4C7798AA0F119471BEE11825BE46202BB79E2A5844495E97C04FF4DF2548A7C0240F88F1CD4E16352A73C17B7F16F07353E53A176D684A9FE0C6BB798E857", "F4A38489E32B45B6F876E3AC2168CA392362DC8F23459C1D1146FC3DBFB7BC9A"}, + { + `0090414C494345313233405941484F4F2E434F4D787968B4FA32C3FD2417842E73BBFEFF2F3C848B6831D7E0EC65228B3937E49863E4C6 +D3B23B0C849CF84241484BFE48F61D59A5B16BA06E6E12D1DA27C5249A421DEBD61B62EAB6746434EBC3CC315E32220B3BADD50BDC4C4E6C147FEDD4 +3D0680512BCBB42C07D47349D2153B70C4E5D7FDFCBFA36EA1A85841B9E46E09A20AE4C7798AA0F119471BEE11825BE46202BB79E2A5844495E97C04 +FF4DF2548A7C0240F88F1CD4E16352A73C17B7F16F07353E53A176D684A9FE0C6BB798E857`, + "F4A38489E32B45B6F876E3AC2168CA392362DC8F23459C1D1146FC3DBFB7BC9A", + }, {"616263", "66C7F0F462EEEDD9D1F2D46BDC10E4E24167C4875CF2F7A2297DA02B8F4BA8E0"}, - {"61626364616263646162636461626364616263646162636461626364616263646162636461626364616263646162636461626364616263646162636461626364", "DEBE9FF92275B8A138604889C18E5A4D6FDB70E5387E5765293dCbA39C0C5732"}, - {"0090414C494345313233405941484F4F2E434F4D787968B4FA32C3FD2417842E73BBFEFF2F3C848B6831D7E0EC65228B3937E49863E4C6D3B23B0C849CF84241484BFE48F61D59A5B16BA06E6E12D1DA27C5249A421DEBD61B62EAB6746434EBC3CC315E32220B3BADD50BDC4C4E6C147FEDD43D0680512BCBB42C07D47349D2153B70C4E5D7FDFCBFA36EA1A85841B9E46E09A20AE4C7798AA0F119471BEE11825BE46202BB79E2A5844495E97C04FF4DF2548A7C0240F88F1CD4E16352A73C17B7F16F07353E53A176D684A9FE0C6BB798E857", "F4A38489E32B45B6F876E3AC2168CA392362DC8F23459C1D1146FC3DBFB7BC9A"}, - {"0090414C494345313233405941484F4F2E434F4D00000000000000000000000000000000000000000000000000000000000000000000E78BCD09746C202378A7E72B12BCE00266B9627ECB0B5A25367AD1AD4CC6242B00CDB9CA7F1E6B0441F658343F4B10297C0EF9B6491082400A62E7A7485735FADD013DE74DA65951C4D76DC89220D5F7777A611B1C38BAE260B175951DC8060C2B3E0165961645281A8626607B917F657D7E9382F1EA5CD931F40F6627F357542653B201686522130D590FB8DE635D8FCA715CC6BF3D05BEF3F75DA5D543454448166612", "26352AF82EC19F207BBC6F9474E11E90CE0F7DDACE03B27F801817E897A81FD5"}, - {"0090414C494345313233405941484F4F2E434F4D787968B4FA32C3FD2417842E73BBFEFF2F3C848B6831D7E0EC65228B3937E49863E4C6D3B23B0C849CF84241484BFE48F61D59A5B16BA06E6E12D1DA27C5249A421DEBD61B62EAB6746434EBC3CC315E32220B3BADD50BDC4C4E6C147FEDD43D0680512BCBB42C07D47349D2153B70C4E5D7FDFCBFA36EA1A85841B9E46E09A23099093BF3C137D8FCBBCDF4A2AE50F3B0F216C3122D79425FE03A45DBFE16553DF79E8DAC1CF0ECBAA2F2B49D51A4B387F2EFAF482339086A27A8E05BAED98B", "E4D1D0C3CA4C7F11BC8FF8CB3F4C02A78F108FA098E51A668487240F75E20F31"}, - {"008842494C4C343536405941484F4F2E434F4D787968B4FA32C3FD2417842E73BBFEFF2F3C848B6831D7E0EC65228B3937E49863E4C6D3B23B0C849CF84241484BFE48F61D59A5B16BA06E6E12D1DA27C5249A421DEBD61B62EAB6746434EBC3CC315E32220B3BADD50BDC4C4E6C147FEDD43D0680512BCBB42C07D47349D2153B70C4E5D7FDFCBFA36EA1A85841B9E46E09A2245493D446C38D8CC0F118374690E7DF633A8A4BFB3329B5ECE604B2B4F37F4353C0869F4B9E17773DE68FEC45E14904E0DEA45BF6CECF9918C85EA047C60A4C", "6B4B6D0E276691BD4A11BF72F4FB501AE309FDACB72FA6CC336E6656119ABD67"}, - {"4D38D2958CA7FD2CFAE3AF04486959CF92C8EF48E8B83A05C112E739D5F181D03082020CA003020102020900AF28725D98D33143300C06082A811CCF550183750500307D310B300906035504060C02636E310B300906035504080C02626A310B300906035504070C02626A310F300D060355040A0C06746F70736563310F300D060355040B0C06746F707365633111300F06035504030C08546F707365634341311F301D06092A864886F70D0109010C10626A40746F707365632E636F6D2E636E301E170D3132303632343037353433395A170D3332303632303037353433395A307D310B300906035504060C02636E310B300906035504080C02626A310B300906035504070C02626A310F300D060355040A0C06746F70736563310F300D060355040B0C06746F707365633111300F06035504030C08546F707365634341311F301D06092A864886F70D0109010C10626A40746F707365632E636F6D2E636E3059301306072A8648CE3D020106082A811CCF5501822D03420004D69C2F1EEC3BFB6B95B30C28085C77B125D77A9C39525D8190768F37D6B205B589DCD316BBE7D89A9DC21917F17799E698531F5E6E3E10BD31370B259C3F81C3A3733071300F0603551D130101FF040530030101FF301D0603551D0E041604148E5D90347858BAAAD870D8BDFBA6A85E7B563B64301F0603551D230418301680148E5D90347858BAAAD870D8BDFBA6A85E7B563B64300B0603551D0F040403020106301106096086480186F8420101040403020057", "C3B02E500A8B60B77DEDCF6F4C11BEF8D56E5CDE708C72065654FD7B2167915A"}, + {`61626364616263646162636461626364616263646162636461626364616263646162636461626364616263646162636461626364616263 +646162636461626364`, "DEBE9FF92275B8A138604889C18E5A4D6FDB70E5387E5765293dCbA39C0C5732"}, + { + `0090414C494345313233405941484F4F2E434F4D787968B4FA32C3FD2417842E73BBFEFF2F3C848B6831D7E0EC65228B3937E49863E4C6 +D3B23B0C849CF84241484BFE48F61D59A5B16BA06E6E12D1DA27C5249A421DEBD61B62EAB6746434EBC3CC315E32220B3BADD50BDC4C4E6C147FEDD4 +3D0680512BCBB42C07D47349D2153B70C4E5D7FDFCBFA36EA1A85841B9E46E09A20AE4C7798AA0F119471BEE11825BE46202BB79E2A5844495E97C04 +FF4DF2548A7C0240F88F1CD4E16352A73C17B7F16F07353E53A176D684A9FE0C6BB798E857`, + "F4A38489E32B45B6F876E3AC2168CA392362DC8F23459C1D1146FC3DBFB7BC9A", + }, + { + `0090414C494345313233405941484F4F2E434F4D00000000000000000000000000000000000000000000000000000000000000000000E7 +8BCD09746C202378A7E72B12BCE00266B9627ECB0B5A25367AD1AD4CC6242B00CDB9CA7F1E6B0441F658343F4B10297C0EF9B6491082400A62E7A748 +5735FADD013DE74DA65951C4D76DC89220D5F7777A611B1C38BAE260B175951DC8060C2B3E0165961645281A8626607B917F657D7E9382F1EA5CD931 +F40F6627F357542653B201686522130D590FB8DE635D8FCA715CC6BF3D05BEF3F75DA5D543454448166612`, + "26352AF82EC19F207BBC6F9474E11E90CE0F7DDACE03B27F801817E897A81FD5", + }, + { + `0090414C494345313233405941484F4F2E434F4D787968B4FA32C3FD2417842E73BBFEFF2F3C848B6831D7E0EC65228B3937E49863E4C6 +D3B23B0C849CF84241484BFE48F61D59A5B16BA06E6E12D1DA27C5249A421DEBD61B62EAB6746434EBC3CC315E32220B3BADD50BDC4C4E6C147FEDD4 +3D0680512BCBB42C07D47349D2153B70C4E5D7FDFCBFA36EA1A85841B9E46E09A23099093BF3C137D8FCBBCDF4A2AE50F3B0F216C3122D79425FE03A +45DBFE16553DF79E8DAC1CF0ECBAA2F2B49D51A4B387F2EFAF482339086A27A8E05BAED98B`, + "E4D1D0C3CA4C7F11BC8FF8CB3F4C02A78F108FA098E51A668487240F75E20F31", + }, + { + `008842494C4C343536405941484F4F2E434F4D787968B4FA32C3FD2417842E73BBFEFF2F3C848B6831D7E0EC65228B3937E49863E4C6D3 +B23B0C849CF84241484BFE48F61D59A5B16BA06E6E12D1DA27C5249A421DEBD61B62EAB6746434EBC3CC315E32220B3BADD50BDC4C4E6C147FEDD43D +0680512BCBB42C07D47349D2153B70C4E5D7FDFCBFA36EA1A85841B9E46E09A2245493D446C38D8CC0F118374690E7DF633A8A4BFB3329B5ECE604B2 +B4F37F4353C0869F4B9E17773DE68FEC45E14904E0DEA45BF6CECF9918C85EA047C60A4C`, + "6B4B6D0E276691BD4A11BF72F4FB501AE309FDACB72FA6CC336E6656119ABD67", + }, + { + `4D38D2958CA7FD2CFAE3AF04486959CF92C8EF48E8B83A05C112E739D5F181D03082020CA003020102020900AF28725D98D33143300C06 +082A811CCF550183750500307D310B300906035504060C02636E310B300906035504080C02626A310B300906035504070C02626A310F300D06035504 +0A0C06746F70736563310F300D060355040B0C06746F707365633111300F06035504030C08546F707365634341311F301D06092A864886F70D010901 +0C10626A40746F707365632E636F6D2E636E301E170D3132303632343037353433395A170D3332303632303037353433395A307D310B300906035504 +060C02636E310B300906035504080C02626A310B300906035504070C02626A310F300D060355040A0C06746F70736563310F300D060355040B0C0674 +6F707365633111300F06035504030C08546F707365634341311F301D06092A864886F70D0109010C10626A40746F707365632E636F6D2E636E305930 +1306072A8648CE3D020106082A811CCF5501822D03420004D69C2F1EEC3BFB6B95B30C28085C77B125D77A9C39525D8190768F37D6B205B589DCD316 +BBE7D89A9DC21917F17799E698531F5E6E3E10BD31370B259C3F81C3A3733071300F0603551D130101FF040530030101FF301D0603551D0E04160414 +8E5D90347858BAAAD870D8BDFBA6A85E7B563B64301F0603551D230418301680148E5D90347858BAAAD870D8BDFBA6A85E7B563B64300B0603551D0F +040403020106301106096086480186F8420101040403020057`, + "C3B02E500A8B60B77DEDCF6F4C11BEF8D56E5CDE708C72065654FD7B2167915A", + }, } for _, tt := range testData { - buf, _ := hex.DecodeString(tt.in) - got := Sum(buf) + buf, _ := hex.DecodeString(strings.ReplaceAll(tt.in, "\n", "")) + got := sm3.Sum(buf) expected, _ := hex.DecodeString(tt.out) if !bytes.Equal(expected, got[:]) { @@ -44,25 +92,29 @@ func TestSM3(t *testing.T) { type sm3func func([]byte) func benchmarkSM3(b *testing.B, length int64, fn sm3func) { + b.Helper() + buf := make([]byte, length) if _, err := io.ReadFull(rand.Reader, buf); err != nil { b.Fatal(err) } + b.SetBytes(length) b.ResetTimer() + for i := 0; i < b.N; i++ { fn(buf) } } func BenchmarkSM3Large(b *testing.B) { - benchmarkSM3(b, 1024*1024, func(buf []byte) { Sum(buf) }) + benchmarkSM3(b, 1024*1024, func(buf []byte) { sm3.Sum(buf) }) } func BenchmarkSM3Normal(b *testing.B) { - benchmarkSM3(b, 1024, func(buf []byte) { Sum(buf) }) + benchmarkSM3(b, 1024, func(buf []byte) { sm3.Sum(buf) }) } func BenchmarkSM3Small(b *testing.B) { - benchmarkSM3(b, 1, func(buf []byte) { Sum(buf) }) + benchmarkSM3(b, 1, func(buf []byte) { sm3.Sum(buf) }) } diff --git a/crypto/sm4/sm4.go b/crypto/sm4/sm4.go index 3006107..5583fe0 100644 --- a/crypto/sm4/sm4.go +++ b/crypto/sm4/sm4.go @@ -8,35 +8,30 @@ package sm4 // #include "../shim.h" -// #cgo linux LDFLAGS: -lcrypto -// #cgo darwin LDFLAGS: -lcrypto -// #cgo windows CFLAGS: -DWIN32_LEAN_AND_MEAN -// #cgo windows pkg-config: libcrypto import "C" import ( "bytes" - "errors" "fmt" "github.com/tongsuo-project/tongsuo-go-sdk/crypto" ) -type SM4Encrypter interface { +type Encrypter interface { // crypto.EncryptionCipherCtx SetPadding(pad bool) EncryptAll(input []byte) ([]byte, error) - SetAAD([]byte) - SetTagLen(int) + SetAAD(aad []byte) + SetTagLen(length int) GetTag() ([]byte, error) } -type SM4Decrypter interface { +type Decrypter interface { // crypto.DecryptionCipherCtx SetPadding(pad bool) DecryptAll(input []byte) ([]byte, error) - SetAAD([]byte) - SetTag([]byte) + SetAAD(aad []byte) + SetTag(tag []byte) } type sm4Encrypter struct { @@ -60,28 +55,32 @@ func getSM4Cipher(mode int) (*crypto.Cipher, error) { var err error switch mode { - case crypto.CIPHER_MODE_ECB: + case crypto.CipherModeECB: cipher, err = crypto.GetCipherByName("SM4-ECB") - case crypto.CIPHER_MODE_CBC: + case crypto.CipherModeCBC: cipher, err = crypto.GetCipherByName("SM4-CBC") - case crypto.CIPHER_MODE_CFB: + case crypto.CipherModeCFB: cipher, err = crypto.GetCipherByName("SM4-CFB") - case crypto.CIPHER_MODE_OFB: + case crypto.CipherModeOFB: cipher, err = crypto.GetCipherByName("SM4-OFB") - case crypto.CIPHER_MODE_CTR: + case crypto.CipherModeCTR: cipher, err = crypto.GetCipherByName("SM4-CTR") - case crypto.CIPHER_MODE_GCM: + case crypto.CipherModeGCM: cipher, err = crypto.GetCipherByName("SM4-GCM") - case crypto.CIPHER_MODE_CCM: + case crypto.CipherModeCCM: cipher, err = crypto.GetCipherByName("SM4-CCM") default: - return nil, errors.New("unsupported sm4 mode") + return nil, fmt.Errorf("unsupported sm4 mode: %w", crypto.ErrUnsupportedMode) } - return cipher, err + if err != nil { + return nil, fmt.Errorf("failed to get cipher: %w", err) + } + + return cipher, nil } -func NewSM4Decrypter(mode int, key []byte, iv []byte) (SM4Decrypter, error) { +func NewDecrypter(mode int, key []byte, iv []byte) (Decrypter, error) { cipher, err := getSM4Cipher(mode) if err != nil { return nil, err @@ -89,26 +88,26 @@ func NewSM4Decrypter(mode int, key []byte, iv []byte) (SM4Decrypter, error) { cctx, err := crypto.NewDecryptionCipherCtx(cipher, nil, nil, nil) if err != nil { - return nil, fmt.Errorf("failed to create decryption cipher ctx %s", err) + return nil, fmt.Errorf("failed to create decryption cipher ctx %w", err) } if len(iv) > 0 { - if mode == crypto.CIPHER_MODE_GCM || mode == crypto.CIPHER_MODE_CCM { + if mode == crypto.CipherModeGCM || mode == crypto.CipherModeCCM { err := cctx.SetCtrl(C.EVP_CTRL_AEAD_SET_IVLEN, len(iv)) if err != nil { - return nil, fmt.Errorf("failed to set IV len to %d: %s", len(iv), err) + return nil, fmt.Errorf("failed to set IV len to %d: %w", len(iv), err) } } } - return &sm4Decrypter{cctx: cctx, key: key, iv: iv}, nil + return &sm4Decrypter{cctx: cctx, key: key, iv: iv, aad: nil, tag: nil}, nil } func (ctx *sm4Decrypter) SetPadding(pad bool) { ctx.cctx.SetPadding(pad) } -func NewSM4Encrypter(mode int, key []byte, iv []byte) (SM4Encrypter, error) { +func NewEncrypter(mode int, key []byte, iv []byte) (Encrypter, error) { var tagLen int cipher, err := getSM4Cipher(mode) @@ -116,36 +115,41 @@ func NewSM4Encrypter(mode int, key []byte, iv []byte) (SM4Encrypter, error) { return nil, err } - if mode == crypto.CIPHER_MODE_GCM { + if mode == crypto.CipherModeGCM { tagLen = 16 } - if mode == crypto.CIPHER_MODE_CCM { + if mode == crypto.CipherModeCCM { tagLen = 12 } cctx, err := crypto.NewEncryptionCipherCtx(cipher, nil, nil, nil) if err != nil { - return nil, fmt.Errorf("failed to create encryption cipher ctx %s", err) + return nil, fmt.Errorf("failed to create encryption cipher ctx %w", err) } if len(iv) > 0 { - if mode == crypto.CIPHER_MODE_GCM || mode == crypto.CIPHER_MODE_CCM { + if mode == crypto.CipherModeGCM || mode == crypto.CipherModeCCM { err := cctx.SetCtrl(C.EVP_CTRL_AEAD_SET_IVLEN, len(iv)) if err != nil { - return nil, fmt.Errorf("could not set IV len to %d: %s", len(iv), err) + return nil, fmt.Errorf("could not set IV len to %d: %w", len(iv), err) } } } - return &sm4Encrypter{cctx: cctx, tagLen: tagLen, key: key, iv: iv}, nil + return &sm4Encrypter{cctx: cctx, tagLen: tagLen, key: key, iv: iv, aad: nil}, nil } func (ctx *sm4Encrypter) GetTag() ([]byte, error) { - return ctx.cctx.GetCtrlBytes(C.EVP_CTRL_AEAD_GET_TAG, ctx.tagLen, ctx.tagLen) + tag, err := ctx.cctx.GetCtrlBytes(C.EVP_CTRL_AEAD_GET_TAG, ctx.tagLen, ctx.tagLen) + if err != nil { + return nil, fmt.Errorf("failed to get tag: %w", err) + } + + return tag, nil } -func (ctx *sm4Encrypter) SetTagLen(len int) { - ctx.tagLen = len +func (ctx *sm4Encrypter) SetTagLen(length int) { + ctx.tagLen = length } func (ctx *sm4Encrypter) SetPadding(pad bool) { @@ -168,89 +172,92 @@ func (ctx *sm4Decrypter) DecryptAll(src []byte) ([]byte, error) { if ctx.tag != nil { err := ctx.cctx.SetCtrlBytes(C.EVP_CTRL_AEAD_SET_TAG, len(ctx.tag), ctx.tag) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to set tag: %w", err) } } err := ctx.cctx.SetKeyAndIV(ctx.key, ctx.iv) if err != nil { - return nil, fmt.Errorf("failed to set key or iv:%s", err) + return nil, fmt.Errorf("failed to set key or iv: %w", err) } var tmplen C.int if ctx.aad != nil { - is_ccm := (C.EVP_CIPHER_flags(C.X_EVP_CIPHER_CTX_cipher((*C.EVP_CIPHER_CTX)(ctx.cctx.Ctx()))) & C.EVP_CIPH_MODE) == C.EVP_CIPH_CCM_MODE + isCcm := (C.EVP_CIPHER_flags(C.X_EVP_CIPHER_CTX_cipher((*C.EVP_CIPHER_CTX)(ctx.cctx.Ctx()))) & + C.EVP_CIPH_MODE) == C.EVP_CIPH_CCM_MODE - if is_ccm { + if isCcm { res := C.EVP_DecryptUpdate((*C.EVP_CIPHER_CTX)(ctx.cctx.Ctx()), nil, &tmplen, nil, C.int(len(src))) if res != 1 { - return nil, fmt.Errorf("failed to set CCM plain text length [result %d]", res) + return nil, fmt.Errorf("failed to set CCM plain text length: %w", crypto.PopError()) } } - res := C.EVP_DecryptUpdate((*C.EVP_CIPHER_CTX)(ctx.cctx.Ctx()), nil, &tmplen, (*C.uchar)(&ctx.aad[0]), C.int(len(ctx.aad))) + res := C.EVP_DecryptUpdate((*C.EVP_CIPHER_CTX)(ctx.cctx.Ctx()), nil, &tmplen, (*C.uchar)(&ctx.aad[0]), + C.int(len(ctx.aad))) if res != 1 { - return nil, fmt.Errorf("failed to set CCM AAD [result %d]", res) + return nil, fmt.Errorf("failed to decrypt: %w", crypto.PopError()) } } res := new(bytes.Buffer) buf, err := ctx.cctx.DecryptUpdate(src) if err != nil { - return nil, fmt.Errorf("Failed to perform decryption: %s", err) + return nil, fmt.Errorf("failed to perform decryption: %w", err) } res.Write(buf) buf2, err := ctx.cctx.DecryptFinal() if err != nil { - return nil, fmt.Errorf("Failed to finalize decryption: %s", err) + return nil, fmt.Errorf("failed to finalize decryption: %w", err) } res.Write(buf2) return res.Bytes(), nil } -func (sm4 *sm4Encrypter) EncryptAll(src []byte) ([]byte, error) { - - is_ccm := (C.EVP_CIPHER_flags(C.X_EVP_CIPHER_CTX_cipher((*C.EVP_CIPHER_CTX)(sm4.cctx.Ctx()))) & C.EVP_CIPH_MODE) == C.EVP_CIPH_CCM_MODE +func (ctx *sm4Encrypter) EncryptAll(src []byte) ([]byte, error) { + isCcm := (C.EVP_CIPHER_flags(C.X_EVP_CIPHER_CTX_cipher((*C.EVP_CIPHER_CTX)(ctx.cctx.Ctx()))) & C.EVP_CIPH_MODE) == + C.EVP_CIPH_CCM_MODE - if is_ccm { - err := sm4.cctx.SetCtrl(C.EVP_CTRL_AEAD_SET_TAG, sm4.tagLen) + if isCcm { + err := ctx.cctx.SetCtrl(C.EVP_CTRL_AEAD_SET_TAG, ctx.tagLen) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to set CCM tag: %w", err) } } - err := sm4.cctx.SetKeyAndIV(sm4.key, sm4.iv) + err := ctx.cctx.SetKeyAndIV(ctx.key, ctx.iv) if err != nil { - return nil, fmt.Errorf("failed to set key or iv:%s", err) + return nil, fmt.Errorf("failed to set key or iv: %w", err) } var tmplen C.int - if sm4.aad != nil { - if is_ccm { - res := C.EVP_EncryptUpdate((*C.EVP_CIPHER_CTX)(sm4.cctx.Ctx()), nil, &tmplen, nil, C.int(len(src))) + if ctx.aad != nil { + if isCcm { + res := C.EVP_EncryptUpdate((*C.EVP_CIPHER_CTX)(ctx.cctx.Ctx()), nil, &tmplen, nil, C.int(len(src))) if res != 1 { - return nil, fmt.Errorf("failed to set CCM plain text length [result %d]", res) + return nil, fmt.Errorf("failed to set CCM plain text length: %w", crypto.PopError()) } } - res := C.EVP_EncryptUpdate((*C.EVP_CIPHER_CTX)(sm4.cctx.Ctx()), nil, &tmplen, (*C.uchar)(&sm4.aad[0]), C.int(len(sm4.aad))) + res := C.EVP_EncryptUpdate((*C.EVP_CIPHER_CTX)(ctx.cctx.Ctx()), nil, &tmplen, (*C.uchar)(&ctx.aad[0]), + C.int(len(ctx.aad))) if res != 1 { - return nil, fmt.Errorf("failed to set AAD [result %d]", res) + return nil, fmt.Errorf("failed to set AAD: %w", crypto.PopError()) } } res := new(bytes.Buffer) - buf, err := sm4.cctx.EncryptUpdate(src) + buf, err := ctx.cctx.EncryptUpdate(src) if err != nil { - return nil, fmt.Errorf("Failed to perform encryption: %s", err) + return nil, fmt.Errorf("failed to perform encryption: %w", err) } res.Write(buf) - buf2, err := sm4.cctx.EncryptFinal() + buf2, err := ctx.cctx.EncryptFinal() if err != nil { - return nil, fmt.Errorf("Failed to finalize encryption: %s", err) + return nil, fmt.Errorf("failed to finalize encryption: %w", err) } res.Write(buf2) diff --git a/crypto/sm4/sm4_test.go b/crypto/sm4/sm4_test.go index 3f0264c..a81b71c 100644 --- a/crypto/sm4/sm4_test.go +++ b/crypto/sm4/sm4_test.go @@ -5,18 +5,25 @@ // in the file LICENSE in the source distribution or at // https://github.com/Tongsuo-Project/tongsuo-go-sdk/blob/main/LICENSE -package sm4 +package sm4_test import ( "bytes" "encoding/hex" + "strings" "testing" "github.com/tongsuo-project/tongsuo-go-sdk/crypto" + "github.com/tongsuo-project/tongsuo-go-sdk/crypto/sm4" ) -func doEncrypt(mode int, key, iv, plainText, cipherText []byte, t *testing.T) { - enc, err := NewSM4Encrypter(mode, key, iv) +const hexPlainText1 = `AAAAAAAAAAAAAAAABBBBBBBBBBBBBBBBCCCCCCCCCCCCCCCCDDDDDDDDDDDDDDDDEEEEEEEEEEEEEEEEFFFFFFFFFFFFFFFFE +EEEEEEEEEEEEEEEAAAAAAAAAAAAAAAA` + +func doEncrypt(t *testing.T, mode int, key, iv, plainText, cipherText []byte) { + t.Helper() + + enc, err := sm4.NewEncrypter(mode, key, iv) if err != nil { t.Fatal("failed to create encrypter: ", err) } @@ -33,8 +40,10 @@ func doEncrypt(mode int, key, iv, plainText, cipherText []byte, t *testing.T) { } } -func doDecrypt(mode int, key, iv, plainText, cipherText []byte, t *testing.T) { - dec, err := NewSM4Decrypter(mode, key, iv) +func doDecrypt(t *testing.T, mode int, key, iv, plainText, cipherText []byte) { + t.Helper() + + dec, err := sm4.NewDecrypter(mode, key, iv) if err != nil { t.Fatal("failed to create decrypter: ", err) } @@ -51,8 +60,10 @@ func doDecrypt(mode int, key, iv, plainText, cipherText []byte, t *testing.T) { } } -func doAEADEncrypt(mode int, key, iv, aad, tag, plainText, cipherText []byte, t *testing.T) { - enc, err := NewSM4Encrypter(mode, key, iv) +func doAEADEncrypt(t *testing.T, mode int, key, iv, aad, tag, plainText, cipherText []byte) { + t.Helper() + + enc, err := sm4.NewEncrypter(mode, key, iv) if err != nil { t.Fatal("failed to create encrypter: ", err) } @@ -79,8 +90,10 @@ func doAEADEncrypt(mode int, key, iv, aad, tag, plainText, cipherText []byte, t } } -func doAEADDecrypt(mode int, key, iv, aad, tag, plainText, cipherText []byte, t *testing.T) { - dec, err := NewSM4Decrypter(mode, key, iv) +func doAEADDecrypt(t *testing.T, mode int, key, iv, aad, tag, plainText, cipherText []byte) { + t.Helper() + + dec, err := sm4.NewDecrypter(mode, key, iv) if err != nil { t.Fatal("failed to create decrypter: ", err) } @@ -99,74 +112,94 @@ func doAEADDecrypt(mode int, key, iv, aad, tag, plainText, cipherText []byte, t } func TestSM4ECB(t *testing.T) { + t.Parallel() + key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") cipherText, _ := hex.DecodeString("681EDF34D206965E86B3E94F536E4246") - doEncrypt(crypto.CIPHER_MODE_ECB, key, nil, plainText, cipherText, t) - doDecrypt(crypto.CIPHER_MODE_ECB, key, nil, plainText, cipherText, t) + doEncrypt(t, crypto.CipherModeECB, key, nil, plainText, cipherText) + doDecrypt(t, crypto.CipherModeECB, key, nil, plainText, cipherText) } func TestSM4CBC(t *testing.T) { + t.Parallel() + key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210") cipherText, _ := hex.DecodeString("2677F46B09C122CC975533105BD4A22AF6125F7275CE552C3A2BBCF533DE8A3B") - doEncrypt(crypto.CIPHER_MODE_CBC, key, iv, plainText, cipherText, t) - doDecrypt(crypto.CIPHER_MODE_CBC, key, iv, plainText, cipherText, t) + doEncrypt(t, crypto.CipherModeCBC, key, iv, plainText, cipherText) + doDecrypt(t, crypto.CipherModeCBC, key, iv, plainText, cipherText) } func TestSM4CFB(t *testing.T) { + t.Parallel() + key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210") cipherText, _ := hex.DecodeString("693D9A535BAD5BB1786F53D7253A70569ED258A85A0467CC92AAB393DD978995") - doEncrypt(crypto.CIPHER_MODE_CFB, key, iv, plainText, cipherText, t) - doDecrypt(crypto.CIPHER_MODE_CFB, key, iv, plainText, cipherText, t) + doEncrypt(t, crypto.CipherModeCFB, key, iv, plainText, cipherText) + doDecrypt(t, crypto.CipherModeCFB, key, iv, plainText, cipherText) } func TestSM4OFB(t *testing.T) { + t.Parallel() + key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210") cipherText, _ := hex.DecodeString("693D9A535BAD5BB1786F53D7253A7056F2075D28B5235F58D50027E4177D2BCE") - doEncrypt(crypto.CIPHER_MODE_OFB, key, iv, plainText, cipherText, t) - doDecrypt(crypto.CIPHER_MODE_OFB, key, iv, plainText, cipherText, t) + doEncrypt(t, crypto.CipherModeOFB, key, iv, plainText, cipherText) + doDecrypt(t, crypto.CipherModeOFB, key, iv, plainText, cipherText) } func TestSM4CTR(t *testing.T) { + t.Parallel() + key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") - plainText, _ := hex.DecodeString("AAAAAAAAAAAAAAAABBBBBBBBBBBBBBBBCCCCCCCCCCCCCCCCDDDDDDDDDDDDDDDDEEEEEEEEEEEEEEEEFFFFFFFFFFFFFFFFEEEEEEEEEEEEEEEEAAAAAAAAAAAAAAAA") - cipherText, _ := hex.DecodeString("C2B4759E78AC3CF43D0852F4E8D5F9FD7256E8A5FCB65A350EE00630912E44492A0B17E1B85B060D0FBA612D8A95831638B361FD5FFACD942F081485A83CA35D") + hexCipherText := `C2B4759E78AC3CF43D0852F4E8D5F9FD7256E8A5FCB65A350EE00630912E44492A0B17E1B85B060D0FBA612D8A95831638 +B361FD5FFACD942F081485A83CA35D` + plainText, _ := hex.DecodeString(strings.ReplaceAll(hexPlainText1, "\n", "")) + cipherText, _ := hex.DecodeString(strings.ReplaceAll(hexCipherText, "\n", "")) - doEncrypt(crypto.CIPHER_MODE_CTR, key, iv, plainText, cipherText, t) - doDecrypt(crypto.CIPHER_MODE_CTR, key, iv, plainText, cipherText, t) + doEncrypt(t, crypto.CipherModeCTR, key, iv, plainText, cipherText) + doDecrypt(t, crypto.CipherModeCTR, key, iv, plainText, cipherText) } func TestSM4GCM(t *testing.T) { + t.Parallel() + key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") iv, _ := hex.DecodeString("00001234567800000000ABCD") aad, _ := hex.DecodeString("FEEDFACEDEADBEEFFEEDFACEDEADBEEFABADDAD2") tag, _ := hex.DecodeString("83DE3541E4C2B58177E065A9BF7B62EC") - plainText, _ := hex.DecodeString("AAAAAAAAAAAAAAAABBBBBBBBBBBBBBBBCCCCCCCCCCCCCCCCDDDDDDDDDDDDDDDDEEEEEEEEEEEEEEEEFFFFFFFFFFFFFFFFEEEEEEEEEEEEEEEEAAAAAAAAAAAAAAAA") - cipherText, _ := hex.DecodeString("17F399F08C67D5EE19D0DC9969C4BB7D5FD46FD3756489069157B282BB200735D82710CA5C22F0CCFA7CBF93D496AC15A56834CBCF98C397B4024A2691233B8D") + hexCipherText := `17F399F08C67D5EE19D0DC9969C4BB7D5FD46FD3756489069157B282BB200735D82710CA5C22F0CCFA7CBF93D496AC15A5 +6834CBCF98C397B4024A2691233B8D` + plainText, _ := hex.DecodeString(strings.ReplaceAll(hexPlainText1, "\n", "")) + cipherText, _ := hex.DecodeString(strings.ReplaceAll(hexCipherText, "\n", "")) - doAEADEncrypt(crypto.CIPHER_MODE_GCM, key, iv, aad, tag, plainText, cipherText, t) - doAEADDecrypt(crypto.CIPHER_MODE_GCM, key, iv, aad, tag, plainText, cipherText, t) + doAEADEncrypt(t, crypto.CipherModeGCM, key, iv, aad, tag, plainText, cipherText) + doAEADDecrypt(t, crypto.CipherModeGCM, key, iv, aad, tag, plainText, cipherText) } func TestSM4CCM(t *testing.T) { + t.Parallel() + key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") iv, _ := hex.DecodeString("00001234567800000000ABCD") aad, _ := hex.DecodeString("FEEDFACEDEADBEEFFEEDFACEDEADBEEFABADDAD2") tag, _ := hex.DecodeString("16842D4FA186F56AB33256971FA110F4") - plainText, _ := hex.DecodeString("AAAAAAAAAAAAAAAABBBBBBBBBBBBBBBBCCCCCCCCCCCCCCCCDDDDDDDDDDDDDDDDEEEEEEEEEEEEEEEEFFFFFFFFFFFFFFFFEEEEEEEEEEEEEEEEAAAAAAAAAAAAAAAA") - cipherText, _ := hex.DecodeString("48AF93501FA62ADBCD414CCE6034D895DDA1BF8F132F042098661572E7483094FD12E518CE062C98ACEE28D95DF4416BED31A2F04476C18BB40C84A74B97DC5B") + hexCipherText := `48AF93501FA62ADBCD414CCE6034D895DDA1BF8F132F042098661572E7483094FD12E518CE062C98ACEE28D95DF4416BED +31A2F04476C18BB40C84A74B97DC5B` + plainText, _ := hex.DecodeString(strings.ReplaceAll(hexPlainText1, "\n", "")) + cipherText, _ := hex.DecodeString(strings.ReplaceAll(hexCipherText, "\n", "")) - doAEADEncrypt(crypto.CIPHER_MODE_CCM, key, iv, aad, tag, plainText, cipherText, t) - doAEADDecrypt(crypto.CIPHER_MODE_CCM, key, iv, aad, tag, plainText, cipherText, t) + doAEADEncrypt(t, crypto.CipherModeCCM, key, iv, aad, tag, plainText, cipherText) + doAEADDecrypt(t, crypto.CipherModeCCM, key, iv, aad, tag, plainText, cipherText) } diff --git a/ctx.go b/ctx.go index ee1505e..b182d6a 100644 --- a/ctx.go +++ b/ctx.go @@ -20,7 +20,6 @@ import "C" import ( "errors" "fmt" - "io/ioutil" "os" "runtime" "sync" @@ -30,47 +29,46 @@ import ( "github.com/tongsuo-project/tongsuo-go-sdk/crypto" ) -var ( - ssl_ctx_idx = C.X_SSL_CTX_new_index() -) +var sslCtxIdx = C.X_SSL_CTX_new_index() type Ctx struct { ctx *C.SSL_CTX cert *crypto.Certificate chain []*crypto.Certificate - key crypto.PrivateKey - verify_cb VerifyCallback - sni_cb TLSExtServernameCallback - alpn_cb TLSExtAlpnCallback + key crypto.PrivateKey + verifyCb VerifyCallback + sniCb TLSExtServernameCallback + alpnCb TLSExtAlpnCallback encCert *crypto.Certificate encKey crypto.PrivateKey - ticket_store_mu sync.Mutex - ticket_store *TicketStore + ticketStoreMu sync.Mutex + ticketStore *TicketStore } //export get_ssl_ctx_idx func get_ssl_ctx_idx() C.int { - return ssl_ctx_idx + return sslCtxIdx } func newCtx(method *C.SSL_METHOD) (*Ctx, error) { runtime.LockOSThread() defer runtime.UnlockOSThread() - ctx := C.SSL_CTX_new(method) - if ctx == nil { - return nil, crypto.ErrorFromErrorQueue() + sslCtx := C.SSL_CTX_new(method) + if sslCtx == nil { + return nil, fmt.Errorf("failed to create SSL CTX: %w", crypto.PopError()) } - c := &Ctx{ctx: ctx} + ctx := &Ctx{ctx: sslCtx} // Bypass go vet check, possibly passing Go type with embedded pointer to C - var p (*C.char) = (*C.char)(unsafe.Pointer(c)) - C.SSL_CTX_set_ex_data(ctx, get_ssl_ctx_idx(), unsafe.Pointer(p)) - runtime.SetFinalizer(c, func(c *Ctx) { + var p (*C.char) = (*C.char)(unsafe.Pointer(ctx)) + C.SSL_CTX_set_ex_data(sslCtx, get_ssl_ctx_idx(), unsafe.Pointer(p)) + runtime.SetFinalizer(ctx, func(c *Ctx) { C.SSL_CTX_free(c.ctx) }) - return c, nil + + return ctx, nil } type SSLVersion int @@ -100,25 +98,27 @@ func NewCtxWithVersion(version SSLVersion) (*Ctx, error) { method = C.TLS_method() } if method == nil { - return nil, errors.New("unknown ssl/tls version") + return nil, fmt.Errorf("unknown ssl/tls version: %w", crypto.ErrUnknownTLSVersion) } - c, err := newCtx(method) + ctx, err := newCtx(method) if err != nil { return nil, err } if enableNTLS { - C.X_SSL_CTX_enable_ntls(c.ctx) - } else if version == AnyVersion { - C.X_SSL_CTX_set_min_proto_version(c.ctx, C.int(TLSv1)) - C.X_SSL_CTX_set_max_proto_version(c.ctx, C.int(TLSv1_3)) + C.X_SSL_CTX_enable_ntls(ctx.ctx) + } + + if version == AnyVersion { + C.X_SSL_CTX_set_min_proto_version(ctx.ctx, C.int(TLSv1)) + C.X_SSL_CTX_set_max_proto_version(ctx.ctx, C.int(TLSv1_3)) } else { - C.X_SSL_CTX_set_min_proto_version(c.ctx, C.int(version)) - C.X_SSL_CTX_set_max_proto_version(c.ctx, C.int(version)) + C.X_SSL_CTX_set_min_proto_version(ctx.ctx, C.int(version)) + C.X_SSL_CTX_set_max_proto_version(ctx.ctx, C.int(version)) } - return c, nil + return ctx, nil } // NewCtx creates a context that supports any TLS version 1.0 and newer. @@ -132,25 +132,25 @@ func NewCtx() (*Ctx, error) { // NewCtxFromFiles calls NewCtx, loads the provided files, and configures the // context to use them. -func NewCtxFromFiles(cert_file string, key_file string) (*Ctx, error) { +func NewCtxFromFiles(certFile string, keyFile string) (*Ctx, error) { ctx, err := NewCtx() if err != nil { return nil, err } - cert_bytes, err := ioutil.ReadFile(cert_file) + certBytes, err := os.ReadFile(certFile) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to read file: %w", err) } - certs := SplitPEM(cert_bytes) + certs := SplitPEM(certBytes) if len(certs) == 0 { - return nil, fmt.Errorf("No PEM certificate found in '%s'", cert_file) + return nil, fmt.Errorf("no PEM certificate found in '%s': %w", certFile, crypto.ErrNoCert) } first, certs := certs[0], certs[1:] cert, err := crypto.LoadCertificateFromPEM(first) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to load cert from pem: %w", err) } err = ctx.UseCertificate(cert) @@ -161,7 +161,7 @@ func NewCtxFromFiles(cert_file string, key_file string) (*Ctx, error) { for _, pem := range certs { cert, err := crypto.LoadCertificateFromPEM(pem) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to load cert from pem: %w", err) } err = ctx.AddChainCertificate(cert) if err != nil { @@ -169,14 +169,14 @@ func NewCtxFromFiles(cert_file string, key_file string) (*Ctx, error) { } } - key_bytes, err := ioutil.ReadFile(key_file) + keyBytes, err := os.ReadFile(keyFile) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to read file: %w", err) } - key, err := crypto.LoadPrivateKeyFromPEM(key_bytes) + key, err := crypto.LoadPrivateKeyFromPEM(keyBytes) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to load private key from pem: %w", err) } err = ctx.UsePrivateKey(key) @@ -189,18 +189,18 @@ func NewCtxFromFiles(cert_file string, key_file string) (*Ctx, error) { // SetEllipticCurve sets the elliptic curve used by the SSL context to // enable an ECDH cipher suite to be selected during the handshake. -func (c *Ctx) SetEllipticCurve(curve crypto.EllipticCurve) error { +func (ctx *Ctx) SetEllipticCurve(curve crypto.EllipticCurve) error { runtime.LockOSThread() defer runtime.UnlockOSThread() - k := C.EC_KEY_new_by_curve_name(C.int(curve)) - if k == nil { - return errors.New("Unknown curve") + key := C.EC_KEY_new_by_curve_name(C.int(curve)) + if key == nil { + return fmt.Errorf("failed to create ec key: %w", crypto.PopError()) } - defer C.EC_KEY_free(k) + defer C.EC_KEY_free(key) - if int(C.X_SSL_CTX_set_tmp_ecdh(c.ctx, k)) != 1 { - return crypto.ErrorFromErrorQueue() + if int(C.X_SSL_CTX_set_tmp_ecdh(ctx.ctx, key)) != 1 { + return fmt.Errorf("failed to set temp ecdh: %w", crypto.PopError()) } return nil @@ -208,48 +208,48 @@ func (c *Ctx) SetEllipticCurve(curve crypto.EllipticCurve) error { // UseSignCertificate configures the context to present the given sign certificate to // peers. -func (c *Ctx) UseSignCertificate(cert *crypto.Certificate) error { +func (ctx *Ctx) UseSignCertificate(cert *crypto.Certificate) error { runtime.LockOSThread() defer runtime.UnlockOSThread() - c.cert = cert - if int(C.SSL_CTX_use_sign_certificate(c.ctx, (*C.X509)(cert.GetCert()))) != 1 { - return crypto.ErrorFromErrorQueue() + ctx.cert = cert + if int(C.SSL_CTX_use_sign_certificate(ctx.ctx, (*C.X509)(cert.GetCert()))) != 1 { + return fmt.Errorf("failed to set sign cert: %w", crypto.PopError()) } return nil } // UseEncryptCertificate configures the context to present the given encrypt certificate to // peers. -func (c *Ctx) UseEncryptCertificate(cert *crypto.Certificate) error { +func (ctx *Ctx) UseEncryptCertificate(cert *crypto.Certificate) error { runtime.LockOSThread() defer runtime.UnlockOSThread() - c.encCert = cert - if int(C.SSL_CTX_use_enc_certificate(c.ctx, (*C.X509)(cert.GetCert()))) != 1 { - return crypto.ErrorFromErrorQueue() + ctx.encCert = cert + if int(C.SSL_CTX_use_enc_certificate(ctx.ctx, (*C.X509)(cert.GetCert()))) != 1 { + return fmt.Errorf("failed to set enc cert: %w", crypto.PopError()) } return nil } // UseCertificate configures the context to present the given certificate to // peers. -func (c *Ctx) UseCertificate(cert *crypto.Certificate) error { +func (ctx *Ctx) UseCertificate(cert *crypto.Certificate) error { runtime.LockOSThread() defer runtime.UnlockOSThread() - c.cert = cert - if int(C.SSL_CTX_use_certificate(c.ctx, (*C.X509)(cert.GetCert()))) != 1 { - return crypto.ErrorFromErrorQueue() + ctx.cert = cert + if int(C.SSL_CTX_use_certificate(ctx.ctx, (*C.X509)(cert.GetCert()))) != 1 { + return fmt.Errorf("failed to set cert: %w", crypto.PopError()) } return nil } // AddChainCertificate adds a certificate to the chain presented in the // handshake. -func (c *Ctx) AddChainCertificate(cert *crypto.Certificate) error { +func (ctx *Ctx) AddChainCertificate(cert *crypto.Certificate) error { runtime.LockOSThread() defer runtime.UnlockOSThread() - c.chain = append(c.chain, cert) - if int(C.X_SSL_CTX_add_extra_chain_cert(c.ctx, (*C.X509)(cert.GetCert()))) != 1 { - return crypto.ErrorFromErrorQueue() + ctx.chain = append(ctx.chain, cert) + if int(C.X_SSL_CTX_add_extra_chain_cert(ctx.ctx, (*C.X509)(cert.GetCert()))) != 1 { + return fmt.Errorf("failed to set chain cert: %w", crypto.PopError()) } // OpenSSL takes ownership via SSL_CTX_add_extra_chain_cert runtime.SetFinalizer(cert, nil) @@ -258,36 +258,36 @@ func (c *Ctx) AddChainCertificate(cert *crypto.Certificate) error { // UseSignPrivateKey configures the context to use the given sign private key for SSL // handshakes. -func (c *Ctx) UseSignPrivateKey(key crypto.PrivateKey) error { +func (ctx *Ctx) UseSignPrivateKey(key crypto.PrivateKey) error { runtime.LockOSThread() defer runtime.UnlockOSThread() - c.key = key - if int(C.SSL_CTX_use_sign_PrivateKey(c.ctx, (*C.EVP_PKEY)(key.EvpPKey()))) != 1 { - return crypto.ErrorFromErrorQueue() + ctx.key = key + if int(C.SSL_CTX_use_sign_PrivateKey(ctx.ctx, (*C.EVP_PKEY)(key.EvpPKey()))) != 1 { + return fmt.Errorf("failed to set sign private key: %w", crypto.PopError()) } return nil } // UseEncryptPrivateKey configures the context to use the given encrypt private key for SSL // handshakes. -func (c *Ctx) UseEncryptPrivateKey(key crypto.PrivateKey) error { +func (ctx *Ctx) UseEncryptPrivateKey(key crypto.PrivateKey) error { runtime.LockOSThread() defer runtime.UnlockOSThread() - c.encKey = key - if int(C.SSL_CTX_use_enc_PrivateKey(c.ctx, (*C.EVP_PKEY)(key.EvpPKey()))) != 1 { - return crypto.ErrorFromErrorQueue() + ctx.encKey = key + if int(C.SSL_CTX_use_enc_PrivateKey(ctx.ctx, (*C.EVP_PKEY)(key.EvpPKey()))) != 1 { + return fmt.Errorf("failed to set enc private key: %w", crypto.PopError()) } return nil } // UsePrivateKey configures the context to use the given private key for SSL // handshakes. -func (c *Ctx) UsePrivateKey(key crypto.PrivateKey) error { +func (ctx *Ctx) UsePrivateKey(key crypto.PrivateKey) error { runtime.LockOSThread() defer runtime.UnlockOSThread() - c.key = key - if int(C.SSL_CTX_use_PrivateKey(c.ctx, (*C.EVP_PKEY)(key.EvpPKey()))) != 1 { - return crypto.ErrorFromErrorQueue() + ctx.key = key + if int(C.SSL_CTX_use_PrivateKey(ctx.ctx, (*C.EVP_PKEY)(key.EvpPKey()))) != 1 { + return fmt.Errorf("failed to set private key: %w", crypto.PopError()) } return nil } @@ -303,9 +303,9 @@ type CertificateStore struct { func NewCertificateStore() (*CertificateStore, error) { s := C.X509_STORE_new() if s == nil { - return nil, errors.New("failed to allocate X509_STORE") + return nil, fmt.Errorf("failed to create X509_STORE: %w", crypto.PopError()) } - store := &CertificateStore{store: s} + store := &CertificateStore{store: s, ctx: nil, certs: nil} runtime.SetFinalizer(store, func(s *CertificateStore) { C.X509_STORE_free(s.store) }) @@ -318,7 +318,7 @@ func (s *CertificateStore) LoadCertificatesFromPEM(data []byte) error { for _, pem := range pems { cert, err := crypto.LoadCertificateFromPEM(pem) if err != nil { - return err + return fmt.Errorf("failed to load cert from pem: %w", err) } err = s.AddCertificate(cert) if err != nil { @@ -330,22 +330,24 @@ func (s *CertificateStore) LoadCertificatesFromPEM(data []byte) error { // GetCertificateStore returns the context's certificate store that will be // used for peer validation. -func (c *Ctx) GetCertificateStore() *CertificateStore { +func (ctx *Ctx) GetCertificateStore() *CertificateStore { // we don't need to dealloc the cert store pointer here, because it points // to a ctx internal. so we do need to keep the ctx around return &CertificateStore{ - store: C.SSL_CTX_get_cert_store(c.ctx), - ctx: c} + store: C.SSL_CTX_get_cert_store(ctx.ctx), + ctx: ctx, + certs: nil, + } } // SetDHParameters sets the DH group (DH parameters) used to // negotiate an emphemeral DH key during handshaking. -func (c *Ctx) SetDHParameters(dh *crypto.DH) error { +func (ctx *Ctx) SetDHParameters(dh *crypto.DH) error { runtime.LockOSThread() defer runtime.UnlockOSThread() - if int(C.X_SSL_CTX_set_tmp_dh(c.ctx, (*C.DH)(dh.GetDH()))) != 1 { - return crypto.ErrorFromErrorQueue() + if int(C.X_SSL_CTX_set_tmp_dh(ctx.ctx, (*C.DH)(dh.GetDH()))) != 1 { + return fmt.Errorf("failed to set temp dh: %w", crypto.PopError()) } return nil } @@ -357,42 +359,42 @@ func (s *CertificateStore) AddCertificate(cert *crypto.Certificate) error { defer runtime.UnlockOSThread() s.certs = append(s.certs, cert) if int(C.X509_STORE_add_cert(s.store, (*C.X509)(cert.GetCert()))) != 1 { - return crypto.ErrorFromErrorQueue() + return fmt.Errorf("failed to add cert: %w", crypto.PopError()) } return nil } type CertificateStoreCtx struct { - ctx *C.X509_STORE_CTX - ssl_ctx *Ctx + ctx *C.X509_STORE_CTX + sslCtx *Ctx } -func (self *CertificateStoreCtx) VerifyResult() VerifyResult { - return VerifyResult(C.X509_STORE_CTX_get_error(self.ctx)) +func (ctx *CertificateStoreCtx) VerifyResult() VerifyResult { + return VerifyResult(C.X509_STORE_CTX_get_error(ctx.ctx)) } -func (self *CertificateStoreCtx) Err() error { - code := C.X509_STORE_CTX_get_error(self.ctx) +func (ctx *CertificateStoreCtx) Err() error { + code := C.X509_STORE_CTX_get_error(ctx.ctx) if code == C.X509_V_OK { return nil } - return fmt.Errorf("openssl: %s", - C.GoString(C.X509_verify_cert_error_string(C.long(code)))) + + return errors.New("x509 verify error: " + C.GoString(C.X509_verify_cert_error_string(C.long(code)))) } -func (self *CertificateStoreCtx) Depth() int { - return int(C.X509_STORE_CTX_get_error_depth(self.ctx)) +func (ctx *CertificateStoreCtx) Depth() int { + return int(C.X509_STORE_CTX_get_error_depth(ctx.ctx)) } // GetCurrentCert the certicate returned is only valid for the lifetime of the underlying // X509_STORE_CTX -func (self *CertificateStoreCtx) GetCurrentCert() *crypto.Certificate { - x509 := C.X509_STORE_CTX_get_current_cert(self.ctx) +func (ctx *CertificateStoreCtx) GetCurrentCert() *crypto.Certificate { + x509 := C.X509_STORE_CTX_get_current_cert(ctx.ctx) if x509 == nil { return nil } // add a ref - if 1 != C.X_X509_add_ref(x509) { + if C.X_X509_add_ref(x509) != 1 { return nil } cert := crypto.NewCertWrapper((unsafe.Pointer(x509))) @@ -406,34 +408,34 @@ func (self *CertificateStoreCtx) GetCurrentCert() *crypto.Certificate { // provided in either the ca_file or the ca_path. // See http://www.openssl.org/docs/ssl/SSL_CTX_load_verify_locations.html for // more. -func (c *Ctx) LoadVerifyLocations(ca_file string, ca_path string) error { +func (ctx *Ctx) LoadVerifyLocations(caFile string, caPath string) error { runtime.LockOSThread() defer runtime.UnlockOSThread() - var c_ca_file, c_ca_path *C.char + var cCaFile, cCaPath *C.char - if ca_path == "" && ca_file == "" { - if C.SSL_CTX_set_default_verify_file(c.ctx) <= 0 { - return crypto.ErrorFromErrorQueue() + if caPath == "" && caFile == "" { + if C.SSL_CTX_set_default_verify_file(ctx.ctx) <= 0 { + return fmt.Errorf("failed to set default verify file: %w", crypto.PopError()) } - if C.SSL_CTX_set_default_verify_dir(c.ctx) <= 0 { - return crypto.ErrorFromErrorQueue() + if C.SSL_CTX_set_default_verify_dir(ctx.ctx) <= 0 { + return fmt.Errorf("failed to set default verify dir: %w", crypto.PopError()) } return nil } - if ca_file != "" { - c_ca_file = C.CString(ca_file) - defer C.free(unsafe.Pointer(c_ca_file)) + if caFile != "" { + cCaFile = C.CString(caFile) + defer C.free(unsafe.Pointer(cCaFile)) } - if ca_path != "" { - c_ca_path = C.CString(ca_path) - defer C.free(unsafe.Pointer(c_ca_path)) + if caPath != "" { + cCaPath = C.CString(caPath) + defer C.free(unsafe.Pointer(cCaPath)) } - if C.SSL_CTX_load_verify_locations(c.ctx, c_ca_file, c_ca_path) <= 0 { - return crypto.ErrorFromErrorQueue() + if C.SSL_CTX_load_verify_locations(ctx.ctx, cCaFile, cCaPath) <= 0 { + return fmt.Errorf("failed to load verify locations: %w", crypto.PopError()) } return nil } @@ -453,20 +455,20 @@ const ( // SetOptions sets context options. See // http://www.openssl.org/docs/ssl/SSL_CTX_set_options.html -func (c *Ctx) SetOptions(options Options) Options { +func (ctx *Ctx) SetOptions(options Options) Options { return Options(C.X_SSL_CTX_set_options( - c.ctx, C.long(options))) + ctx.ctx, C.long(options))) } -func (c *Ctx) ClearOptions(options Options) Options { +func (ctx *Ctx) ClearOptions(options Options) Options { return Options(C.X_SSL_CTX_clear_options( - c.ctx, C.long(options))) + ctx.ctx, C.long(options))) } // GetOptions returns context options. See // https://www.openssl.org/docs/ssl/SSL_CTX_set_options.html -func (c *Ctx) GetOptions() Options { - return Options(C.X_SSL_CTX_get_options(c.ctx)) +func (ctx *Ctx) GetOptions() Options { + return Options(C.X_SSL_CTX_get_options(ctx.ctx)) } type Modes int @@ -478,14 +480,14 @@ const ( // SetMode sets context modes. See // http://www.openssl.org/docs/ssl/SSL_CTX_set_mode.html -func (c *Ctx) SetMode(modes Modes) Modes { - return Modes(C.X_SSL_CTX_set_mode(c.ctx, C.long(modes))) +func (ctx *Ctx) SetMode(modes Modes) Modes { + return Modes(C.X_SSL_CTX_set_mode(ctx.ctx, C.long(modes))) } // GetMode returns context modes. See // http://www.openssl.org/docs/ssl/SSL_CTX_set_mode.html -func (c *Ctx) GetMode() Modes { - return Modes(C.X_SSL_CTX_get_mode(c.ctx)) +func (ctx *Ctx) GetMode() Modes { + return Modes(C.X_SSL_CTX_get_mode(ctx.ctx)) } type VerifyOptions int @@ -500,19 +502,18 @@ const ( type VerifyCallback func(ok bool, store *CertificateStoreCtx) bool //export go_ssl_ctx_verify_cb_thunk -func go_ssl_ctx_verify_cb_thunk(p unsafe.Pointer, ok C.int, ctx *C.X509_STORE_CTX) C.int { +func go_ssl_ctx_verify_cb_thunk(callback unsafe.Pointer, ok C.int, ctx *C.X509_STORE_CTX) C.int { defer func() { if err := recover(); err != nil { - // TODO logger - //logger.Critf("openssl: verify callback panic'd: %v", err) + // logger.Critf("openssl: verify callback panic'd: %v", err) os.Exit(1) } }() - verify_cb := (*Ctx)(p).verify_cb + verifyCb := (*Ctx)(callback).verifyCb // set up defaults just in case verify_cb is nil - if verify_cb != nil { - store := &CertificateStoreCtx{ctx: ctx} - if verify_cb(ok == 1, store) { + if verifyCb != nil { + store := &CertificateStoreCtx{ctx: ctx, sslCtx: nil} + if verifyCb(ok == 1, store) { ok = 1 } else { ok = 0 @@ -523,43 +524,43 @@ func go_ssl_ctx_verify_cb_thunk(p unsafe.Pointer, ok C.int, ctx *C.X509_STORE_CT // SetVerify controls peer verification settings. See // http://www.openssl.org/docs/ssl/SSL_CTX_set_verify.html -func (c *Ctx) SetVerify(options VerifyOptions, verify_cb VerifyCallback) { - c.verify_cb = verify_cb - if verify_cb != nil { - C.SSL_CTX_set_verify(c.ctx, C.int(options), (*[0]byte)(C.X_SSL_CTX_verify_cb)) +func (ctx *Ctx) SetVerify(options VerifyOptions, verifyCb VerifyCallback) { + ctx.verifyCb = verifyCb + if verifyCb != nil { + C.SSL_CTX_set_verify(ctx.ctx, C.int(options), (*[0]byte)(C.X_SSL_CTX_verify_cb)) } else { - C.SSL_CTX_set_verify(c.ctx, C.int(options), nil) + C.SSL_CTX_set_verify(ctx.ctx, C.int(options), nil) } } -func (c *Ctx) SetVerifyMode(options VerifyOptions) { - c.SetVerify(options, c.verify_cb) +func (ctx *Ctx) SetVerifyMode(options VerifyOptions) { + ctx.SetVerify(options, ctx.verifyCb) } -func (c *Ctx) SetVerifyCallback(verify_cb VerifyCallback) { - c.SetVerify(c.VerifyMode(), verify_cb) +func (ctx *Ctx) SetVerifyCallback(verifyCb VerifyCallback) { + ctx.SetVerify(ctx.VerifyMode(), verifyCb) } -func (c *Ctx) GetVerifyCallback() VerifyCallback { - return c.verify_cb +func (ctx *Ctx) GetVerifyCallback() VerifyCallback { + return ctx.verifyCb } -func (c *Ctx) VerifyMode() VerifyOptions { - return VerifyOptions(C.SSL_CTX_get_verify_mode(c.ctx)) +func (ctx *Ctx) VerifyMode() VerifyOptions { + return VerifyOptions(C.SSL_CTX_get_verify_mode(ctx.ctx)) } // SetVerifyDepth controls how many certificates deep the certificate // verification logic is willing to follow a certificate chain. See // https://www.openssl.org/docs/ssl/SSL_CTX_set_verify.html -func (c *Ctx) SetVerifyDepth(depth int) { - C.SSL_CTX_set_verify_depth(c.ctx, C.int(depth)) +func (ctx *Ctx) SetVerifyDepth(depth int) { + C.SSL_CTX_set_verify_depth(ctx.ctx, C.int(depth)) } // GetVerifyDepth controls how many certificates deep the certificate // verification logic is willing to follow a certificate chain. See // https://www.openssl.org/docs/ssl/SSL_CTX_set_verify.html -func (c *Ctx) GetVerifyDepth() int { - return int(C.SSL_CTX_get_verify_depth(c.ctx)) +func (ctx *Ctx) GetVerifyDepth() int { + return int(C.SSL_CTX_get_verify_depth(ctx.ctx)) } type TLSExtServernameCallback func(ssl *SSL) SSLTLSExtErr @@ -567,12 +568,13 @@ type TLSExtServernameCallback func(ssl *SSL) SSLTLSExtErr // SetTLSExtServernameCallback sets callback function for Server Name Indication // (SNI) rfc6066 (http://tools.ietf.org/html/rfc6066). See // http://stackoverflow.com/questions/22373332/serving-multiple-domains-in-one-box-with-sni -func (c *Ctx) SetTLSExtServernameCallback(sni_cb TLSExtServernameCallback) { - c.sni_cb = sni_cb - C.X_SSL_CTX_set_tlsext_servername_callback(c.ctx, (*[0]byte)(C.sni_cb)) +func (ctx *Ctx) SetTLSExtServernameCallback(sniCb TLSExtServernameCallback) { + ctx.sniCb = sniCb + C.X_SSL_CTX_set_tlsext_servername_callback(ctx.ctx, (*[0]byte)(C.sni_cb)) } -type TLSExtAlpnCallback func(ssl *SSL, out unsafe.Pointer, outlen unsafe.Pointer, in unsafe.Pointer, inlen uint, arg unsafe.Pointer) SSLTLSExtErr +type TLSExtAlpnCallback func(ssl *SSL, out unsafe.Pointer, outlen unsafe.Pointer, in unsafe.Pointer, inlen uint, + arg unsafe.Pointer) SSLTLSExtErr // SetServerALPNProtos sets the ALPN protocol list, if failed the negotiation will lead to server handshake failure func (ctx *Ctx) SetServerALPNProtos(protos []string) { @@ -583,7 +585,11 @@ func (ctx *Ctx) SetServerALPNProtos(protos []string) { protoList = append(protoList, []byte(proto)...) // Add the protocol content } - ctx.alpn_cb = func(ssl *SSL, out unsafe.Pointer, outlen unsafe.Pointer, in unsafe.Pointer, inlen uint, arg unsafe.Pointer) SSLTLSExtErr { + ctx.alpnCb = func(_ *SSL, out unsafe.Pointer, outlen unsafe.Pointer, in unsafe.Pointer, inlen uint, + arg unsafe.Pointer, + ) SSLTLSExtErr { + _ = arg // Unused + // Use OpenSSL function to select the protocol ret := C.SSL_select_next_proto( (**C.uchar)(out), @@ -594,7 +600,7 @@ func (ctx *Ctx) SetServerALPNProtos(protos []string) { C.uint(inlen), ) - if ret != OPENSSL_NPN_NEGOTIATED { + if ret != NPNNegotiated { return SSLTLSExtErrAlertFatal } @@ -619,21 +625,21 @@ func (ctx *Ctx) SetClientALPNProtos(protos []string) error { // Call the C function to set the ALPN protocols ret := C.SSL_CTX_set_alpn_protos(ctx.ctx, cProtoList, C.uint(len(protoList))) if ret != 0 { - return errors.New("failed to set ALPN protocols") + return fmt.Errorf("failed to set ALPN protocols: %w", crypto.PopError()) } return nil } -func (c *Ctx) SetSessionId(session_id []byte) error { +func (ctx *Ctx) SetSessionID(sessionID []byte) error { runtime.LockOSThread() defer runtime.UnlockOSThread() var ptr *C.uchar - if len(session_id) > 0 { - ptr = (*C.uchar)(unsafe.Pointer(&session_id[0])) + if len(sessionID) > 0 { + ptr = (*C.uchar)(unsafe.Pointer(&sessionID[0])) } - if int(C.SSL_CTX_set_session_id_context(c.ctx, ptr, - C.uint(len(session_id)))) == 0 { - return crypto.ErrorFromErrorQueue() + if int(C.SSL_CTX_set_session_id_context(ctx.ctx, ptr, + C.uint(len(sessionID)))) == 0 { + return fmt.Errorf("failed to set session id ctx: %w", crypto.PopError()) } return nil } @@ -641,28 +647,28 @@ func (c *Ctx) SetSessionId(session_id []byte) error { // SetCipherList sets the list of available ciphers. The format of the list is // described at http://www.openssl.org/docs/apps/ciphers.html, but see // http://www.openssl.org/docs/ssl/SSL_CTX_set_cipher_list.html for more. -func (c *Ctx) SetCipherList(list string) error { +func (ctx *Ctx) SetCipherList(list string) error { runtime.LockOSThread() defer runtime.UnlockOSThread() clist := C.CString(list) defer C.free(unsafe.Pointer(clist)) - if int(C.SSL_CTX_set_cipher_list(c.ctx, clist)) == 0 { - return crypto.ErrorFromErrorQueue() + if int(C.SSL_CTX_set_cipher_list(ctx.ctx, clist)) == 0 { + return fmt.Errorf("failed to set cipher list: %w", crypto.PopError()) } return nil } -func (c *Ctx) SetCipherSuites(suites string) error { +func (ctx *Ctx) SetCipherSuites(suites string) error { runtime.LockOSThread() defer runtime.UnlockOSThread() csuits := C.CString(suites) defer C.free(unsafe.Pointer(csuits)) - if int(C.SSL_CTX_set_ciphersuites(c.ctx, csuits)) == 0 { - return crypto.ErrorFromErrorQueue() + if int(C.SSL_CTX_set_ciphersuites(ctx.ctx, csuits)) == 0 { + return fmt.Errorf("failed to set ciphersuites: %w", crypto.PopError()) } return nil } @@ -674,40 +680,36 @@ const ( SessionCacheClient SessionCacheModes = C.SSL_SESS_CACHE_CLIENT SessionCacheServer SessionCacheModes = C.SSL_SESS_CACHE_SERVER SessionCacheBoth SessionCacheModes = C.SSL_SESS_CACHE_BOTH - NoAutoClear SessionCacheModes = C.SSL_SESS_CACHE_NO_AUTO_CLEAR - NoInternalLookup SessionCacheModes = C.SSL_SESS_CACHE_NO_INTERNAL_LOOKUP - NoInternalStore SessionCacheModes = C.SSL_SESS_CACHE_NO_INTERNAL_STORE - NoInternal SessionCacheModes = C.SSL_SESS_CACHE_NO_INTERNAL ) // SetSessionCacheMode enables or disables session caching. See // http://www.openssl.org/docs/ssl/SSL_CTX_set_session_cache_mode.html -func (c *Ctx) SetSessionCacheMode(modes SessionCacheModes) SessionCacheModes { +func (ctx *Ctx) SetSessionCacheMode(modes SessionCacheModes) SessionCacheModes { return SessionCacheModes( - C.X_SSL_CTX_set_session_cache_mode(c.ctx, C.long(modes))) + C.X_SSL_CTX_set_session_cache_mode(ctx.ctx, C.long(modes))) } // Set session cache timeout. Returns previously set value. // See https://www.openssl.org/docs/ssl/SSL_CTX_set_timeout.html -func (c *Ctx) SetTimeout(t time.Duration) time.Duration { - prev := C.X_SSL_CTX_set_timeout(c.ctx, C.long(t/time.Second)) +func (ctx *Ctx) SetTimeout(t time.Duration) time.Duration { + prev := C.X_SSL_CTX_set_timeout(ctx.ctx, C.long(t/time.Second)) return time.Duration(prev) * time.Second } // GetTimeout Get session cache timeout. // See https://www.openssl.org/docs/ssl/SSL_CTX_set_timeout.html -func (c *Ctx) GetTimeout() time.Duration { - return time.Duration(C.X_SSL_CTX_get_timeout(c.ctx)) * time.Second +func (ctx *Ctx) GetTimeout() time.Duration { + return time.Duration(C.X_SSL_CTX_get_timeout(ctx.ctx)) * time.Second } // SessSetCacheSize Set session cache size. Returns previously set value. // https://www.openssl.org/docs/ssl/SSL_CTX_sess_set_cache_size.html -func (c *Ctx) SessSetCacheSize(t int) int { - return int(C.X_SSL_CTX_sess_set_cache_size(c.ctx, C.long(t))) +func (ctx *Ctx) SessSetCacheSize(t int) int { + return int(C.X_SSL_CTX_sess_set_cache_size(ctx.ctx, C.long(t))) } // SessGetCacheSize Get session cache size. // https://www.openssl.org/docs/ssl/SSL_CTX_sess_set_cache_size.html -func (c *Ctx) SessGetCacheSize() int { - return int(C.X_SSL_CTX_sess_get_cache_size(c.ctx)) +func (ctx *Ctx) SessGetCacheSize() int { + return int(C.X_SSL_CTX_sess_get_cache_size(ctx.ctx)) } diff --git a/ctx_test.go b/ctx_test.go index 8e2ca54..7f5be87 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -12,36 +12,46 @@ // See the License for the specific language governing permissions and // limitations under the License. -package tongsuogo +package tongsuogo_test import ( "testing" "time" + + ts "github.com/tongsuo-project/tongsuo-go-sdk" ) func TestCtxTimeoutOption(t *testing.T) { - ctx, _ := NewCtx() + t.Parallel() + + ctx, _ := ts.NewCtx() oldTimeout1 := ctx.GetTimeout() newTimeout1 := oldTimeout1 + (time.Duration(99) * time.Second) oldTimeout2 := ctx.SetTimeout(newTimeout1) newTimeout2 := ctx.GetTimeout() + if oldTimeout1 != oldTimeout2 { t.Error("SetTimeout() returns something undocumented") } + if newTimeout1 != newTimeout2 { t.Error("SetTimeout() does not save anything to ctx") } } func TestCtxSessCacheSizeOption(t *testing.T) { - ctx, _ := NewCtx() + t.Parallel() + + ctx, _ := ts.NewCtx() oldSize1 := ctx.SessGetCacheSize() newSize1 := oldSize1 + 42 oldSize2 := ctx.SessSetCacheSize(newSize1) newSize2 := ctx.SessGetCacheSize() + if oldSize1 != oldSize2 { t.Error("SessSetCacheSize() returns something undocumented") } + if newSize1 != newSize2 { t.Error("SessSetCacheSize() does not save anything to ctx") } diff --git a/examples/cert_gen/main.go b/examples/cert_gen/main.go index b4aaf09..9bb14ce 100644 --- a/examples/cert_gen/main.go +++ b/examples/cert_gen/main.go @@ -1,10 +1,11 @@ package main import ( - "github.com/tongsuo-project/tongsuo-go-sdk/crypto" "math/big" "path/filepath" "time" + + "github.com/tongsuo-project/tongsuo-go-sdk/crypto" ) const genPath = "./examples/cert_gen/" @@ -12,18 +13,21 @@ const genPath = "./examples/cert_gen/" func main() { // Helper function: generate and save key generateAndSaveKey := func(filename string) crypto.PrivateKey { - key, err := crypto.GenerateECKey(crypto.Sm2Curve) + key, err := crypto.GenerateECKey(crypto.SM2Curve) if err != nil { panic(err) } + pem, err := key.MarshalPKCS8PrivateKeyPEM() if err != nil { panic(err) } + err = crypto.SavePEMToFile(pem, filename) if err != nil { panic(err) } + return key } @@ -33,23 +37,27 @@ func main() { if err != nil { panic(err) } + err = cert.AddExtensions(extensions) if err != nil { panic(err) } + return cert } // Helper function: sign and save certificate signAndSaveCert := func(cert *crypto.Certificate, caKey crypto.PrivateKey, filename string) { - err := cert.Sign(caKey, crypto.EVP_SM3) + err := cert.Sign(caKey, crypto.MDSM3) if err != nil { panic(err) } + certPem, err := cert.MarshalPEM() if err != nil { panic(err) } + err = crypto.SavePEMToFile(certPem, filename) if err != nil { panic(err) @@ -57,10 +65,11 @@ func main() { } // Create CA certificate - caKey, err := crypto.GenerateECKey(crypto.Sm2Curve) + caKey, err := crypto.GenerateECKey(crypto.SM2Curve) if err != nil { panic(err) } + caInfo := crypto.CertificateInfo{ Serial: big.NewInt(1), Expires: 87600 * time.Hour, // 10 years @@ -69,10 +78,10 @@ func main() { CommonName: "CA", } caExtensions := map[crypto.NID]string{ - crypto.NID_basic_constraints: "critical,CA:TRUE", - crypto.NID_key_usage: "critical,digitalSignature,keyCertSign,cRLSign", - crypto.NID_subject_key_identifier: "hash", - crypto.NID_authority_key_identifier: "keyid:always,issuer", + crypto.NidBasicConstraints: "critical,CA:TRUE", + crypto.NidKeyUsage: "critical,digitalSignature,keyCertSign,cRLSign", + crypto.NidSubjectKeyIdentifier: "hash", + crypto.NidAuthorityKeyIdentifier: "keyid:always,issuer", } ca := createCertificate(caInfo, caKey, caExtensions) caFile := filepath.Join(genPath, "chain-ca.crt") @@ -102,8 +111,8 @@ func main() { CommonName: "localhost", } extensions := map[crypto.NID]string{ - crypto.NID_basic_constraints: "critical,CA:FALSE", - crypto.NID_key_usage: info.keyUsage, + crypto.NidBasicConstraints: "critical,CA:FALSE", + crypto.NidKeyUsage: info.keyUsage, } cert := createCertificate(certInfo, key, extensions) @@ -111,6 +120,7 @@ func main() { if err != nil { panic(err) } + certFile := filepath.Join(genPath, info.name+".crt") signAndSaveCert(cert, caKey, certFile) } diff --git a/examples/sm2_encrypt/main.go b/examples/sm2_encrypt/main.go index a2f8941..defda17 100644 --- a/examples/sm2_encrypt/main.go +++ b/examples/sm2_encrypt/main.go @@ -3,6 +3,7 @@ package main import ( "encoding/hex" "fmt" + "github.com/tongsuo-project/tongsuo-go-sdk/crypto" "github.com/tongsuo-project/tongsuo-go-sdk/crypto/sm2" ) @@ -29,6 +30,7 @@ func main() { if err != nil { panic(err) } + fmt.Printf("SM2(%s)=%s\n", data, hex.EncodeToString(ciphertext)) // Decrypt ciphertext diff --git a/examples/sm2_keygen/main.go b/examples/sm2_keygen/main.go index 5bcb864..cd98aa9 100644 --- a/examples/sm2_keygen/main.go +++ b/examples/sm2_keygen/main.go @@ -23,6 +23,7 @@ func main() { if err != nil { panic(err) } + fmt.Printf("Private Key:\n%s\n", pem) pub := priv.Public() diff --git a/examples/sm2_sign/main.go b/examples/sm2_sign/main.go index c68692e..839ef87 100644 --- a/examples/sm2_sign/main.go +++ b/examples/sm2_sign/main.go @@ -35,6 +35,7 @@ func main() { if err != nil { panic(err) } + fmt.Printf("SM2withSM3(%s)=(r, s)=(%s, %s)\n", data, hex.EncodeToString(r.Bytes()), hex.EncodeToString(s.Bytes())) pub := priv.Public() diff --git a/examples/sm2_signasn1/main.go b/examples/sm2_signasn1/main.go index 4bbca21..96aa242 100644 --- a/examples/sm2_signasn1/main.go +++ b/examples/sm2_signasn1/main.go @@ -24,6 +24,7 @@ Wt3Te/d/8Mr57Tf886i09VwDhSMmH8pmNq/mp6+ioUgqYG9cs6GLLioe func main() { data := []byte("hello world") + priv, err := crypto.LoadPrivateKeyFromPEM(sm2_key1) if err != nil { panic(err) @@ -34,6 +35,7 @@ func main() { if err != nil { panic(err) } + fmt.Printf("SM2withSM3(%s)=%s\n", data, hex.EncodeToString(signature)) pub := priv.Public() diff --git a/examples/sm3/main.go b/examples/sm3/main.go index 2cd060c..a07ea38 100644 --- a/examples/sm3/main.go +++ b/examples/sm3/main.go @@ -26,11 +26,12 @@ func main() { if _, err := h.Write([]byte("hello")); err != nil { log.Fatal(err) } + if _, err := h.Write([]byte(" world")); err != nil { log.Fatal(err) } - var res [sm3.SM3_DIGEST_LENGTH]byte + var res [sm3.MDSize]byte fmt.Printf("SM3(%s)=%x\n", msg, h.Sum(res[:0])) } diff --git a/examples/sm4/main.go b/examples/sm4/main.go index a16c830..d7b5c54 100644 --- a/examples/sm4/main.go +++ b/examples/sm4/main.go @@ -23,7 +23,7 @@ func sm4CBCEncrypt() { plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210") cipherText, _ := hex.DecodeString("2677F46B09C122CC975533105BD4A22AF6125F7275CE552C3A2BBCF533DE8A3B") - enc, err := sm4.NewSM4Encrypter(crypto.CIPHER_MODE_CBC, key, iv) + enc, err := sm4.NewEncrypter(crypto.CipherModeCBC, key, iv) if err != nil { log.Fatal("failed to create encrypter: ", err) } @@ -52,7 +52,7 @@ func sm4CBCDecrypt() { plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210") cipherText, _ := hex.DecodeString("2677F46B09C122CC975533105BD4A22AF6125F7275CE552C3A2BBCF533DE8A3B") - enc, err := sm4.NewSM4Decrypter(crypto.CIPHER_MODE_CBC, key, iv) + enc, err := sm4.NewDecrypter(crypto.CipherModeCBC, key, iv) if err != nil { log.Fatal("failed to create decrypter: ", err) } @@ -83,7 +83,7 @@ func sm4GCMEncrypt() { plainText, _ := hex.DecodeString("AAAAAAAAAAAAAAAABBBBBBBBBBBBBBBBCCCCCCCCCCCCCCCCDDDDDDDDDDDDDDDDEEEEEEEEEEEEEEEEFFFFFFFFFFFFFFFFEEEEEEEEEEEEEEEEAAAAAAAAAAAAAAAA") cipherText, _ := hex.DecodeString("17F399F08C67D5EE19D0DC9969C4BB7D5FD46FD3756489069157B282BB200735D82710CA5C22F0CCFA7CBF93D496AC15A56834CBCF98C397B4024A2691233B8D") - enc, err := sm4.NewSM4Encrypter(crypto.CIPHER_MODE_GCM, key, iv) + enc, err := sm4.NewEncrypter(crypto.CipherModeGCM, key, iv) if err != nil { log.Fatal("failed to create encrypter: ", err) } @@ -125,7 +125,7 @@ func sm4GCMDecrypt() { plainText, _ := hex.DecodeString("AAAAAAAAAAAAAAAABBBBBBBBBBBBBBBBCCCCCCCCCCCCCCCCDDDDDDDDDDDDDDDDEEEEEEEEEEEEEEEEFFFFFFFFFFFFFFFFEEEEEEEEEEEEEEEEAAAAAAAAAAAAAAAA") cipherText, _ := hex.DecodeString("17F399F08C67D5EE19D0DC9969C4BB7D5FD46FD3756489069157B282BB200735D82710CA5C22F0CCFA7CBF93D496AC15A56834CBCF98C397B4024A2691233B8D") - dec, err := sm4.NewSM4Decrypter(crypto.CIPHER_MODE_GCM, key, iv) + dec, err := sm4.NewDecrypter(crypto.CipherModeGCM, key, iv) if err != nil { log.Fatal("failed to create decrypter: ", err) } diff --git a/examples/tlcp_client/main.go b/examples/tlcp_client/main.go index 169829e..eec9bf5 100644 --- a/examples/tlcp_client/main.go +++ b/examples/tlcp_client/main.go @@ -46,6 +46,7 @@ func main() { flag.Parse() var version ts.SSLVersion + switch tlsVersion { case "TLSv1.3": version = ts.TLSv1_3 @@ -60,6 +61,7 @@ func main() { default: version = ts.NTLS } + ctx, err := ts.NewCtxWithVersion(version) if err != nil { panic("NewCtxWithVersion failed: " + err.Error()) @@ -77,11 +79,13 @@ func main() { if err := ctx.SetCipherList(cipherSuites); err != nil { panic(err) } + if signCertFile != "" { signCertPEM, err := os.ReadFile(signCertFile) if err != nil { panic(err) } + signCert, err := crypto.LoadCertificateFromPEM(signCertPEM) if err != nil { panic(err) @@ -97,6 +101,7 @@ func main() { if err != nil { panic(err) } + signKey, err := crypto.LoadPrivateKeyFromPEM(signKeyPEM) if err != nil { panic(err) @@ -112,6 +117,7 @@ func main() { if err != nil { panic(err) } + encCert, err := crypto.LoadCertificateFromPEM(encCertPEM) if err != nil { panic(err) @@ -176,11 +182,13 @@ func main() { request := text + "\n" fmt.Println(">>>\n" + request) + if _, err := conn.Write([]byte(request)); err != nil { panic(err) } buffer := make([]byte, 4096) + n, err := conn.Read(buffer) if err != nil { fmt.Println("read error:", err) diff --git a/examples/tlcp_server/main.go b/examples/tlcp_server/main.go index 197e875..3cca632 100644 --- a/examples/tlcp_server/main.go +++ b/examples/tlcp_server/main.go @@ -59,7 +59,6 @@ func handleConn(conn net.Conn) { defer func(conn net.Conn) { err := conn.Close() if err != nil { - } }(conn) @@ -71,6 +70,7 @@ func handleConn(conn net.Conn) { } ntls := conn.(*ts.Conn) + ver, err := ntls.GetVersion() if err != nil { log.Println("failed get version: ", err) @@ -98,6 +98,7 @@ func handleConn(conn net.Conn) { func newTLSServer(acceptAddr string, certKeyPairs map[string]crypto.GMDoubleCertKey, cafile string, alpnProtocols []string, tlsVersion string) (net.Listener, error) { var version ts.SSLVersion + switch tlsVersion { case "TLSv1.3": version = ts.TLSv1_3 @@ -112,6 +113,7 @@ func newTLSServer(acceptAddr string, certKeyPairs map[string]crypto.GMDoubleCert default: version = ts.TLSv1_3 } + ctx, err := ts.NewCtxWithVersion(version) if err != nil { log.Println(err) @@ -216,11 +218,13 @@ func loadCertAndKeyForSSL(ssl *ts.SSL, certKeyPair crypto.GMDoubleCertKey) error log.Println(err) return err } + encCert, err := crypto.LoadCertificateFromPEM(encCertPEM) if err != nil { log.Println(err) return err } + err = ctx.UseEncryptCertificate(encCert) if err != nil { return err @@ -231,11 +235,13 @@ func loadCertAndKeyForSSL(ssl *ts.SSL, certKeyPair crypto.GMDoubleCertKey) error log.Println(err) return err } + signCert, err := crypto.LoadCertificateFromPEM(signCertPEM) if err != nil { log.Println(err) return err } + err = ctx.UseSignCertificate(signCert) if err != nil { return err @@ -246,11 +252,13 @@ func loadCertAndKeyForSSL(ssl *ts.SSL, certKeyPair crypto.GMDoubleCertKey) error log.Println(err) return err } + encKey, err := crypto.LoadPrivateKeyFromPEM(encKeyPEM) if err != nil { log.Println(err) return err } + err = ctx.UseEncryptPrivateKey(encKey) if err != nil { return err @@ -261,11 +269,13 @@ func loadCertAndKeyForSSL(ssl *ts.SSL, certKeyPair crypto.GMDoubleCertKey) error log.Println(err) return err } + signKey, err := crypto.LoadPrivateKeyFromPEM(signKeyPEM) if err != nil { log.Println(err) return err } + err = ctx.UseSignPrivateKey(signKey) if err != nil { return err @@ -283,11 +293,13 @@ func loadCertAndKey(ctx *ts.Ctx, pair crypto.GMDoubleCertKey) (err error) { log.Println(err) return err } + encCert, err := crypto.LoadCertificateFromPEM(encCertPEM) if err != nil { log.Println(err) return err } + err = ctx.UseEncryptCertificate(encCert) if err != nil { return err @@ -298,11 +310,13 @@ func loadCertAndKey(ctx *ts.Ctx, pair crypto.GMDoubleCertKey) (err error) { log.Println(err) return err } + signCert, err := crypto.LoadCertificateFromPEM(signCertPEM) if err != nil { log.Println(err) return err } + err = ctx.UseSignCertificate(signCert) if err != nil { return err @@ -313,11 +327,13 @@ func loadCertAndKey(ctx *ts.Ctx, pair crypto.GMDoubleCertKey) (err error) { log.Println(err) return err } + encKey, err := crypto.LoadPrivateKeyFromPEM(encKeyPEM) if err != nil { log.Println(err) return err } + err = ctx.UseEncryptPrivateKey(encKey) if err != nil { return err @@ -328,18 +344,19 @@ func loadCertAndKey(ctx *ts.Ctx, pair crypto.GMDoubleCertKey) (err error) { log.Println(err) return err } + signKey, err := crypto.LoadPrivateKeyFromPEM(signKeyPEM) if err != nil { log.Println(err) return err } + err = ctx.UseSignPrivateKey(signKey) if err != nil { return err } return nil - } func main() { diff --git a/http.go b/http.go index 37239bd..1bbdf95 100644 --- a/http.go +++ b/http.go @@ -15,27 +15,34 @@ package tongsuogo import ( + "fmt" "net/http" + "time" ) +const defaultReadHeaderTimeout = 120 + // ListenAndServeTLS will take an http.Handler and serve it using OpenSSL over // the given tcp address, configured to use the provided cert and key files. -func ListenAndServeTLS(addr string, cert_file string, key_file string, - handler http.Handler) error { +func ListenAndServeTLS(addr string, certFile string, keyFile string, + handler http.Handler, +) error { return ServerListenAndServeTLS( - &http.Server{Addr: addr, Handler: handler}, cert_file, key_file) + &http.Server{Addr: addr, Handler: handler, ReadHeaderTimeout: defaultReadHeaderTimeout * time.Second}, + certFile, keyFile) } // ServerListenAndServeTLS will take an http.Server and serve it using OpenSSL // configured to use the provided cert and key files. func ServerListenAndServeTLS(srv *http.Server, - cert_file, key_file string) error { + certFile, keyFile string, +) error { addr := srv.Addr if addr == "" { addr = ":https" } - ctx, err := NewCtxFromFiles(cert_file, key_file) + ctx, err := NewCtxFromFiles(certFile, keyFile) if err != nil { return err } @@ -45,7 +52,12 @@ func ServerListenAndServeTLS(srv *http.Server, return err } - return srv.Serve(l) + err = srv.Serve(l) + if err != nil { + return fmt.Errorf("failed to serve tls: %w", err) + } + + return nil } // TODO: http client integration diff --git a/init.go b/init.go index 9df5977..fe1442d 100644 --- a/init.go +++ b/init.go @@ -24,12 +24,6 @@ package tongsuogo // #include "shim.h" import "C" -import ( - "fmt" -) - func init() { - if rc := C.X_tongsuogo_init(); rc != 0 { - panic(fmt.Errorf("X_tongsuogo_init failed with %d", rc)) - } + C.X_tongsuogo_init() } diff --git a/net.go b/net.go index 362ca0d..706920e 100644 --- a/net.go +++ b/net.go @@ -16,25 +16,31 @@ package tongsuogo import ( "errors" + "fmt" "net" ) +var ErrNilParam = errors.New("nil parameter") + type listener struct { net.Listener ctx *Ctx } -func (l *listener) Accept() (c net.Conn, err error) { - c, err = l.Listener.Accept() +func (l *listener) Accept() (net.Conn, error) { + conn, err := l.Listener.Accept() if err != nil { - return nil, err + return nil, fmt.Errorf("failed to accept: %w", err) } - ssl_c, err := Server(c, l.ctx) + + server, err := Server(conn, l.ctx) if err != nil { - c.Close() + conn.Close() + return nil, err } - return ssl_c, nil + + return server, nil } // NewListener wraps an existing net.Listener such that all accepted @@ -51,12 +57,14 @@ func NewListener(inner net.Listener, ctx *Ctx) net.Listener { // an OpenSSL server connection using the provided context ctx. func Listen(network, laddr string, ctx *Ctx) (net.Listener, error) { if ctx == nil { - return nil, errors.New("no ssl context provided") + return nil, fmt.Errorf("no ssl context provided: %w", ErrNilParam) } + l, err := net.Listen(network, laddr) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to listen: %w", err) } + return NewListener(l, ctx), nil } @@ -95,57 +103,71 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags, host string) (*Conn, // If session is not nil it will be used to resume the tls state. The session // can be retrieved from the GetSession method on the Conn. func DialSession(network, addr string, ctx *Ctx, flags DialFlags, - session []byte, host string) (*Conn, error) { - + session []byte, host string, +) (*Conn, error) { var err error if host == "" { host, _, err = net.SplitHostPort(addr) } + if err != nil { - return nil, err + return nil, fmt.Errorf("failed to split host and port: %w", err) } + if ctx == nil { var err error + ctx, err = NewCtx() if err != nil { return nil, err } - // TODO: use operating system default certificate chain? } - c, err := net.Dial(network, addr) + + conn, err := net.Dial(network, addr) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to connect: %w", err) } - conn, err := Client(c, ctx) + + client, err := Client(conn, ctx) if err != nil { - c.Close() + conn.Close() + return nil, err } + if session != nil { - err := conn.setSession(session) + err := client.setSession(session) if err != nil { - c.Close() + conn.Close() + return nil, err } } + if flags&DisableSNI == 0 { - err = conn.SetTlsExtHostName(host) + err = client.SetTLSExtHostName(host) if err != nil { - conn.Close() - return nil, err + client.Close() + + return nil, fmt.Errorf("failed to set TLS host name: %w", err) } } - err = conn.Handshake() + + err = client.Handshake() if err != nil { - conn.Close() - return nil, err + client.Close() + + return nil, fmt.Errorf("failed to handshake: %w", err) } + if flags&InsecureSkipHostVerification == 0 { - err = conn.VerifyHostname(host) + err = client.VerifyHostname(host) if err != nil { - conn.Close() - return nil, err + client.Close() + + return nil, fmt.Errorf("failed to verify host name: %w", err) } } - return conn, nil + + return client, nil } diff --git a/ntls_test.go b/ntls_test.go index 407f269..b9f3a66 100644 --- a/ntls_test.go +++ b/ntls_test.go @@ -1,4 +1,10 @@ -package tongsuogo +// Copyright 2024 The Tongsuo Project Authors. All Rights Reserved. +// +// Licensed under the Apache License 2.0 (the "License"). You may not use +// this file except in compliance with the License. You can obtain a copy +// in the file LICENSE in the source distribution or at +// https://github.com/Tongsuo-Project/tongsuo-go-sdk/blob/main/LICENSE +package tongsuogo_test import ( "bufio" @@ -12,6 +18,7 @@ import ( "testing" "time" + ts "github.com/tongsuo-project/tongsuo-go-sdk" "github.com/tongsuo-project/tongsuo-go-sdk/crypto" ) @@ -24,43 +31,54 @@ const ( enableSNI = true testCertDir = "test/certs/sm2" + testCaFile = "test/certs/sm2/chain-ca.crt" + testRequest = "hello tongsuo\n" ) +func generateSM2KeyAndSave(t *testing.T, filename string) crypto.PrivateKey { + t.Helper() + + key, err := crypto.GenerateECKey(crypto.SM2Curve) + if err != nil { + t.Fatal(err) + } + + pem, err := key.MarshalPKCS8PrivateKeyPEM() + if err != nil { + t.Fatal(err) + } + + err = crypto.SavePEMToFile(pem, filename) + if err != nil { + t.Fatal(err) + } + + return key +} + func TestCAGenerateSM2AndNTLS(t *testing.T) { + t.Parallel() // Create a temporary directory to store generated keys and certificates tmpDir, err := os.MkdirTemp("", "tongsuo-test-*") if err != nil { t.Fatalf("failed to create temporary directory: %v", err) } - defer os.RemoveAll(tmpDir) - // Helper function: generate and save key - generateAndSaveKey := func(filename string) crypto.PrivateKey { - key, err := crypto.GenerateECKey(crypto.Sm2Curve) - if err != nil { - t.Fatal(err) - } - pem, err := key.MarshalPKCS8PrivateKeyPEM() - if err != nil { - t.Fatal(err) - } - err = crypto.SavePEMToFile(pem, filename) - if err != nil { - t.Fatal(err) - } - return key - } + t.Cleanup(func() { + os.RemoveAll(tmpDir) + }) - // Helper function: sign and save certificate signAndSaveCert := func(cert *crypto.Certificate, caKey crypto.PrivateKey, filename string) { - err := cert.Sign(caKey, crypto.EVP_SM3) + err := cert.Sign(caKey, crypto.MDSM3) if err != nil { t.Fatal(err) } + certPem, err := cert.MarshalPEM() if err != nil { t.Fatal(err) } + err = crypto.SavePEMToFile(certPem, filename) if err != nil { t.Fatal(err) @@ -68,34 +86,38 @@ func TestCAGenerateSM2AndNTLS(t *testing.T) { } // Create CA certificate - caKey, err := crypto.GenerateECKey(crypto.Sm2Curve) + caKey, err := crypto.GenerateECKey(crypto.SM2Curve) if err != nil { t.Fatal(err) } + caInfo := crypto.CertificateInfo{ Serial: big.NewInt(1), + Issued: 0, Expires: 87600 * time.Hour, // 10 years Country: "US", Organization: "Test CA", CommonName: "CA", } caExtensions := map[crypto.NID]string{ - crypto.NID_basic_constraints: "critical,CA:TRUE", - crypto.NID_key_usage: "critical,digitalSignature,keyCertSign,cRLSign", - crypto.NID_subject_key_identifier: "hash", - crypto.NID_authority_key_identifier: "keyid:always,issuer", + crypto.NidBasicConstraints: "critical,CA:TRUE", + crypto.NidKeyUsage: "critical,digitalSignature,keyCertSign,cRLSign", + crypto.NidSubjectKeyIdentifier: "hash", + crypto.NidAuthorityKeyIdentifier: "keyid:always,issuer", } - ca, err := crypto.NewCertificate(&caInfo, caKey) + + caCert, err := crypto.NewCertificate(&caInfo, caKey) if err != nil { t.Fatal(err) } - err = ca.AddExtensions(caExtensions) + + err = caCert.AddExtensions(caExtensions) if err != nil { t.Fatal(err) } // Save CA certificate to tmpDir caCertFile := filepath.Join(tmpDir, "chain-ca.crt") - signAndSaveCert(ca, caKey, caCertFile) + signAndSaveCert(caCert, caKey, caCertFile) // Define additional certificate information certInfos := []struct { @@ -110,8 +132,8 @@ func TestCAGenerateSM2AndNTLS(t *testing.T) { // Create additional certificates for _, info := range certInfos { - keyFile := filepath.Join(tmpDir, fmt.Sprintf("%s.key", info.name)) - key := generateAndSaveKey(keyFile) + keyFile := filepath.Join(tmpDir, info.name+".key") + key := generateSM2KeyAndSave(t, keyFile) certInfo := crypto.CertificateInfo{ Serial: big.NewInt(1), Issued: 0, @@ -121,31 +143,37 @@ func TestCAGenerateSM2AndNTLS(t *testing.T) { CommonName: "localhost", } extensions := map[crypto.NID]string{ - crypto.NID_basic_constraints: "critical,CA:FALSE", - crypto.NID_key_usage: info.keyUsage, + crypto.NidBasicConstraints: "critical,CA:FALSE", + crypto.NidKeyUsage: info.keyUsage, } + cert, err := crypto.NewCertificate(&certInfo, key) if err != nil { t.Fatal(err) } + err = cert.AddExtensions(extensions) if err != nil { t.Fatal(err) } - err = cert.SetIssuer(ca) + + err = cert.SetIssuer(caCert) if err != nil { t.Fatal(err) } - certFile := filepath.Join(tmpDir, fmt.Sprintf("%s.crt", info.name)) + + certFile := filepath.Join(tmpDir, info.name+".crt") signAndSaveCert(cert, caKey, certFile) } t.Run("NTLS Test", func(t *testing.T) { + t.Parallel() testNTLS(t, tmpDir) }) } func testNTLS(t *testing.T, tmpDir string) { + t.Helper() // Use the generated keys and certificates from tmpDir to test NTLS cases := []struct { cipher string @@ -157,9 +185,13 @@ func testNTLS(t *testing.T, tmpDir string) { runServer bool }{ { - cipher: ECCSM2Cipher, - runServer: internalServer, - caFile: filepath.Join(tmpDir, "chain-ca.crt"), + cipher: ECCSM2Cipher, + signCertFile: "", + signKeyFile: "", + encCertFile: "", + encKeyFile: "", + caFile: filepath.Join(tmpDir, "chain-ca.crt"), + runServer: internalServer, }, { cipher: ECDHESM2Cipher, @@ -172,113 +204,48 @@ func testNTLS(t *testing.T, tmpDir string) { }, } - for _, c := range cases { - t.Run(c.cipher, func(t *testing.T) { - if c.runServer { - server, err := newNTLSServer(t, tmpDir, func(sslctx *Ctx) error { - return sslctx.SetCipherList(c.cipher) - }) - - if err != nil { - t.Error(err) - return - } - defer server.Close() - go server.Run() - } - - ctx, err := NewCtxWithVersion(NTLS) + for _, item := range cases { + t.Run(item.cipher, func(t *testing.T) { + server, err := newNTLSServer(t, tmpDir, func(sslctx *ts.Ctx) error { + return sslctx.SetCipherList(item.cipher) + }) if err != nil { t.Error(err) - return - } - if err := ctx.SetCipherList(c.cipher); err != nil { - t.Error(err) return } - if c.signCertFile != "" { - signCertPEM, err := os.ReadFile(c.signCertFile) - if err != nil { - t.Error(err) - return - } - signCert, err := crypto.LoadCertificateFromPEM(signCertPEM) - if err != nil { - t.Error(err) - return - } - - if err := ctx.UseSignCertificate(signCert); err != nil { - t.Error(err) - return - } - } + defer server.Close() + go server.Run() - if c.signKeyFile != "" { - signKeyPEM, err := os.ReadFile(c.signKeyFile) - if err != nil { - t.Error(err) - return - } - signKey, err := crypto.LoadPrivateKeyFromPEM(signKeyPEM) - if err != nil { - t.Error(err) - return - } + ctx, err := ts.NewCtxWithVersion(ts.NTLS) + if err != nil { + t.Error(err) - if err := ctx.UseSignPrivateKey(signKey); err != nil { - t.Error(err) - return - } + return } - if c.encCertFile != "" { - encCertPEM, err := os.ReadFile(c.encCertFile) - if err != nil { - t.Error(err) - return - } - encCert, err := crypto.LoadCertificateFromPEM(encCertPEM) - if err != nil { - t.Error(err) - return - } + if err := ctx.SetCipherList(item.cipher); err != nil { + t.Error(err) - if err := ctx.UseEncryptCertificate(encCert); err != nil { - t.Error(err) - return - } + return } - if c.encKeyFile != "" { - encKeyPEM, err := os.ReadFile(c.encKeyFile) - if err != nil { - t.Error(err) - return - } - - encKey, err := crypto.LoadPrivateKeyFromPEM(encKeyPEM) - if err != nil { - t.Error(err) - return - } - - if err := ctx.UseEncryptPrivateKey(encKey); err != nil { - t.Error(err) - return - } + err = ctxSetGMDoubleCertKey(ctx, item.signCertFile, item.signKeyFile, item.encCertFile, item.encKeyFile) + if err != nil { + t.Error(err) + return } - if c.caFile != "" { - if err := ctx.LoadVerifyLocations(c.caFile, ""); err != nil { + if item.caFile != "" { + if err := ctx.LoadVerifyLocations(item.caFile, ""); err != nil { t.Error(err) return } } - conn, err := DialSession("tcp", "127.0.0.1:4433", ctx, InsecureSkipHostVerification, nil, "") + conn, err := ts.DialSession(server.Addr().Network(), server.Addr().String(), ctx, + ts.InsecureSkipHostVerification, nil, "") if err != nil { t.Error(err) return @@ -293,8 +260,7 @@ func testNTLS(t *testing.T, tmpDir string) { t.Log("current cipher", cipher) - request := "hello tongsuo\n" - if _, err := conn.Write([]byte(request)); err != nil { + if _, err := conn.Write([]byte(testRequest)); err != nil { t.Error(err) return } @@ -305,7 +271,7 @@ func testNTLS(t *testing.T, tmpDir string) { return } - if resp != request { + if resp != testRequest { t.Error("response data is not expected: ", resp) return } @@ -314,6 +280,8 @@ func testNTLS(t *testing.T, tmpDir string) { } func TestNTLS(t *testing.T) { + t.Parallel() + cases := []struct { cipher string signCertFile string @@ -321,12 +289,14 @@ func TestNTLS(t *testing.T) { encCertFile string encKeyFile string caFile string - runServer bool }{ { - cipher: ECCSM2Cipher, - runServer: internalServer, - caFile: filepath.Join(testCertDir, "chain-ca.crt"), + cipher: ECCSM2Cipher, + signCertFile: "", + signKeyFile: "", + encCertFile: "", + encKeyFile: "", + caFile: filepath.Join(testCertDir, "chain-ca.crt"), }, { cipher: ECDHESM2Cipher, @@ -335,117 +305,52 @@ func TestNTLS(t *testing.T) { encCertFile: filepath.Join(testCertDir, "client_enc.crt"), encKeyFile: filepath.Join(testCertDir, "client_enc.key"), caFile: filepath.Join(testCertDir, "chain-ca.crt"), - runServer: internalServer, }, } - for _, c := range cases { - t.Run(c.cipher, func(t *testing.T) { - if c.runServer { - server, err := newNTLSServer(t, testCertDir, func(sslctx *Ctx) error { - return sslctx.SetCipherList(c.cipher) - }) + for _, item := range cases { + item := item + t.Run(item.cipher, func(t *testing.T) { + t.Parallel() - if err != nil { - t.Error(err) - return - } - defer server.Close() - go server.Run() - } - - ctx, err := NewCtxWithVersion(NTLS) + server, err := newNTLSServer(t, testCertDir, func(sslctx *ts.Ctx) error { + return sslctx.SetCipherList(item.cipher) + }) if err != nil { t.Error(err) return } - if err := ctx.SetCipherList(c.cipher); err != nil { - t.Error(err) - return - } - - if c.signCertFile != "" { - signCertPEM, err := os.ReadFile(c.signCertFile) - if err != nil { - t.Error(err) - return - } - signCert, err := crypto.LoadCertificateFromPEM(signCertPEM) - if err != nil { - t.Error(err) - return - } - - if err := ctx.UseSignCertificate(signCert); err != nil { - t.Error(err) - return - } - } + defer server.Close() - if c.signKeyFile != "" { - signKeyPEM, err := os.ReadFile(c.signKeyFile) - if err != nil { - t.Error(err) - return - } - signKey, err := crypto.LoadPrivateKeyFromPEM(signKeyPEM) - if err != nil { - t.Error(err) - return - } + go server.Run() - if err := ctx.UseSignPrivateKey(signKey); err != nil { - t.Error(err) - return - } + ctx, err := ts.NewCtxWithVersion(ts.NTLS) + if err != nil { + t.Error(err) + return } - if c.encCertFile != "" { - encCertPEM, err := os.ReadFile(c.encCertFile) - if err != nil { - t.Error(err) - return - } - encCert, err := crypto.LoadCertificateFromPEM(encCertPEM) - if err != nil { - t.Error(err) - return - } - - if err := ctx.UseEncryptCertificate(encCert); err != nil { - t.Error(err) - return - } + if err := ctx.SetCipherList(item.cipher); err != nil { + t.Error(err) + return } - if c.encKeyFile != "" { - encKeyPEM, err := os.ReadFile(c.encKeyFile) - if err != nil { - t.Error(err) - return - } - - encKey, err := crypto.LoadPrivateKeyFromPEM(encKeyPEM) - if err != nil { - t.Error(err) - return - } - - if err := ctx.UseEncryptPrivateKey(encKey); err != nil { - t.Error(err) - return - } + err = ctxSetGMDoubleCertKey(ctx, item.signCertFile, item.signKeyFile, item.encCertFile, item.encKeyFile) + if err != nil { + t.Error(err) + return } - if c.caFile != "" { - if err := ctx.LoadVerifyLocations(c.caFile, ""); err != nil { + if item.caFile != "" { + if err := ctx.LoadVerifyLocations(item.caFile, ""); err != nil { t.Error(err) return } } - conn, err := DialSession("tcp", "127.0.0.1:4433", ctx, InsecureSkipHostVerification, nil, "") + conn, err := ts.DialSession(server.Addr().Network(), server.Addr().String(), ctx, + ts.InsecureSkipHostVerification, nil, "") if err != nil { t.Error(err) return @@ -460,8 +365,7 @@ func TestNTLS(t *testing.T) { t.Log("current cipher", cipher) - request := "hello tongsuo\n" - if _, err := conn.Write([]byte(request)); err != nil { + if _, err := conn.Write([]byte(testRequest)); err != nil { t.Error(err) return } @@ -472,7 +376,7 @@ func TestNTLS(t *testing.T) { return } - if resp != request { + if resp != testRequest { t.Error("response data is not expected: ", resp) return } @@ -480,8 +384,10 @@ func TestNTLS(t *testing.T) { } } -func newNTLSServer(t *testing.T, testDir string, options ...func(sslctx *Ctx) error) (*echoServer, error) { - ctx, err := NewCtxWithVersion(NTLS) +func newNTLSServer(t *testing.T, testDir string, options ...func(sslctx *ts.Ctx) error) (*echoServer, error) { + t.Helper() + + ctx, err := ts.NewCtxWithVersion(ts.NTLS) if err != nil { t.Error(err) return nil, err @@ -499,75 +405,15 @@ func newNTLSServer(t *testing.T, testDir string, options ...func(sslctx *Ctx) er return nil, err } - encCertPEM, err := os.ReadFile(filepath.Join(testDir, "server_enc.crt")) - if err != nil { - t.Error(err) - return nil, err - } - - signCertPEM, err := os.ReadFile(filepath.Join(testDir, "server_sign.crt")) - if err != nil { - t.Error(err) - return nil, err - } - - encCert, err := crypto.LoadCertificateFromPEM(encCertPEM) - if err != nil { - t.Error(err) - return nil, err - } - - signCert, err := crypto.LoadCertificateFromPEM(signCertPEM) - if err != nil { - t.Error(err) - return nil, err - } - - if err := ctx.UseEncryptCertificate(encCert); err != nil { - t.Error(err) - return nil, err - } - - if err := ctx.UseSignCertificate(signCert); err != nil { - t.Error(err) - return nil, err - } - - encKeyPEM, err := os.ReadFile(filepath.Join(testDir, "server_enc.key")) - if err != nil { - t.Error(err) - return nil, err - } - - signKeyPEM, err := os.ReadFile(filepath.Join(testDir, "server_sign.key")) - if err != nil { - t.Error(err) - return nil, err - } - - encKey, err := crypto.LoadPrivateKeyFromPEM(encKeyPEM) + err = ctxSetGMDoubleCertKey(ctx, filepath.Join(testDir, "server_sign.crt"), + filepath.Join(testDir, "server_sign.key"), filepath.Join(testDir, "server_enc.crt"), + filepath.Join(testDir, "server_enc.key")) if err != nil { t.Error(err) return nil, err } - signKey, err := crypto.LoadPrivateKeyFromPEM(signKeyPEM) - if err != nil { - t.Error(err) - return nil, err - } - - if err := ctx.UseEncryptPrivateKey(encKey); err != nil { - t.Error(err) - return nil, err - } - - if err := ctx.UseSignPrivateKey(signKey); err != nil { - t.Error(err) - return nil, err - } - - lis, err := Listen("tcp", "127.0.0.1:4433", ctx) + lis, err := ts.Listen("tcp", "localhost:0", ctx) if err != nil { t.Error(err) return nil, err @@ -581,15 +427,21 @@ type echoServer struct { } func (s *echoServer) Close() error { - return s.Listener.Close() + err := s.Listener.Close() + if err != nil { + return fmt.Errorf("failed to close listener: %w", err) + } + + return nil } func (s *echoServer) Run() error { for { conn, err := s.Listener.Accept() if err != nil { - return err + return fmt.Errorf("failed to accept: %w", err) } + go handleConn(conn) } } @@ -598,8 +450,9 @@ func (s *echoServer) RunForALPN() error { for { conn, err := s.Listener.Accept() if err != nil { - return err + return fmt.Errorf("failed to accept: %w", err) } + go handleConnForALPN(conn) } } @@ -610,13 +463,13 @@ func handleConn(conn net.Conn) { // Read incoming data into buffer req, err := bufio.NewReader(conn).ReadString('\n') if err != nil { - log.Printf("Error reading incoming data: %v", err) + log.Printf("Error reading incoming data: %s", err) return } // Send a response back to the client if _, err = conn.Write([]byte(req + "\n")); err != nil { - log.Printf("Unable to send response: %v", err) + log.Printf("Unable to send response: %s", err) return } } @@ -627,20 +480,25 @@ func handleConnForALPN(conn net.Conn) { // Read incoming data into buffer req, err := bufio.NewReader(conn).ReadString('\n') if err != nil { - log.Printf("Error reading incoming data: %v", err) + log.Printf("Error reading incoming data: %s", err) return } // Send a response back to the client if _, err = conn.Write([]byte(req + "\n")); err != nil { - log.Printf("Unable to send response: %v", err) + log.Printf("Unable to send response: %s", err) + return + } + + ntls, ok := conn.(*ts.Conn) + if !ok { + log.Printf("Connection is not an NTLS connection") return } - ntls := conn.(*Conn) protocol, err := ntls.GetALPNNegotiated() if err != nil { - log.Printf("Error getting negotiated protocol: %v", err) + log.Printf("Error getting negotiated protocol: %s", err) return } @@ -648,6 +506,7 @@ func handleConnForALPN(conn net.Conn) { } func TestSNI(t *testing.T) { + t.Parallel() // Run server certFiles, err := ReadCertificateFiles("test/sni_certs") if err != nil { @@ -655,10 +514,9 @@ func TestSNI(t *testing.T) { return } - server, err := newNTLSServerWithSNI(t, testCertDir, certFiles, enableSNI, func(sslctx *Ctx) error { + server, err := newNTLSServerWithSNI(t, testCertDir, certFiles, enableSNI, func(sslctx *ts.Ctx) error { return sslctx.SetCipherList("ECC-SM2-SM4-CBC-SM3") }) - if err != nil { t.Error(err) return @@ -668,10 +526,7 @@ func TestSNI(t *testing.T) { go server.Run() // Run Client - caFile := "test/certs/sm2/chain-ca.crt" - connAddr := "127.0.0.1:4433" - - ctx, err := NewCtxWithVersion(NTLS) + ctx, err := ts.NewCtxWithVersion(ts.NTLS) if err != nil { t.Error(err) return @@ -682,7 +537,7 @@ func TestSNI(t *testing.T) { return } - if err := ctx.LoadVerifyLocations(caFile, ""); err != nil { + if err := ctx.LoadVerifyLocations(testCaFile, ""); err != nil { t.Error(err) return } @@ -691,7 +546,8 @@ func TestSNI(t *testing.T) { serverName := "default" // Connect to the server - conn, err := DialSession("tcp", connAddr, ctx, InsecureSkipHostVerification, nil, serverName) + conn, err := ts.DialSession(server.Addr().Network(), server.Addr().String(), ctx, + ts.InsecureSkipHostVerification, nil, serverName) if err != nil { t.Error(err) return @@ -706,8 +562,7 @@ func TestSNI(t *testing.T) { t.Log("current cipher", cipher) - request := "hello tongsuo\n" - if _, err := conn.Write([]byte(request)); err != nil { + if _, err := conn.Write([]byte(testRequest)); err != nil { t.Error(err) return } @@ -718,14 +573,86 @@ func TestSNI(t *testing.T) { return } - if resp != request { - t.Error("response data is not expected: ", resp) - return + if resp != testRequest { + t.Error("response data is not expected: ", resp) + return + } +} + +func ctxSetGMDoubleCertKey(ctx *ts.Ctx, signCertFile, signKeyFile, encCertFile, encKeyFile string) error { + if signCertFile != "" { + signCertPEM, err := os.ReadFile(signCertFile) + if err != nil { + return fmt.Errorf("failed to read sign cert file: %w", err) + } + + signCert, err := crypto.LoadCertificateFromPEM(signCertPEM) + if err != nil { + return fmt.Errorf("failed to load sign cert: %w", err) + } + + if err := ctx.UseSignCertificate(signCert); err != nil { + return fmt.Errorf("failed to set sign cert: %w", err) + } + } + + if signKeyFile != "" { + signKeyPEM, err := os.ReadFile(signKeyFile) + if err != nil { + return fmt.Errorf("failed to read sign key file: %w", err) + } + + signKey, err := crypto.LoadPrivateKeyFromPEM(signKeyPEM) + if err != nil { + return fmt.Errorf("failed to load sign key: %w", err) + } + + if err := ctx.UseSignPrivateKey(signKey); err != nil { + return fmt.Errorf("failed to set sign key: %w", err) + } + } + + if encCertFile != "" { + encCertPEM, err := os.ReadFile(encCertFile) + if err != nil { + return fmt.Errorf("failed to read enc cert file: %w", err) + } + + encCert, err := crypto.LoadCertificateFromPEM(encCertPEM) + if err != nil { + return fmt.Errorf("failed to load enc cert: %w", err) + } + + if err := ctx.UseEncryptCertificate(encCert); err != nil { + return fmt.Errorf("failed to set enc cert: %w", err) + } + } + + if encKeyFile != "" { + encKeyPEM, err := os.ReadFile(encKeyFile) + if err != nil { + return fmt.Errorf("failed to read enc key file: %w", err) + } + + encKey, err := crypto.LoadPrivateKeyFromPEM(encKeyPEM) + if err != nil { + return fmt.Errorf("failed to load enc key: %w", err) + } + + if err := ctx.UseEncryptPrivateKey(encKey); err != nil { + return fmt.Errorf("failed to set enc key: %w", err) + } } + + return nil } -func newNTLSServerWithSNI(t *testing.T, testDir string, certKeyPairs map[string]crypto.GMDoubleCertKey, sni bool, options ...func(sslctx *Ctx) error) (*echoServer, error) { - ctx, err := NewCtxWithVersion(NTLS) +func newNTLSServerWithSNI(t *testing.T, testDir string, certKeyPairs map[string]crypto.GMDoubleCertKey, sni bool, + options ...func(sslctx *ts.Ctx) error, +) (*echoServer, error) { + t.Helper() + + ctx, err := ts.NewCtxWithVersion(ts.NTLS) if err != nil { t.Error(err) return nil, err @@ -745,94 +672,33 @@ func newNTLSServerWithSNI(t *testing.T, testDir string, certKeyPairs map[string] // Set SNI callback if sni == true { - ctx.SetTLSExtServernameCallback(func(ssl *SSL) SSLTLSExtErr { + ctx.SetTLSExtServernameCallback(func(ssl *ts.SSL) ts.SSLTLSExtErr { serverName := ssl.GetServername() log.Printf("SNI: Client requested hostname: %s\n", serverName) if certKeyPair, ok := certKeyPairs[serverName]; ok { if err := loadCertAndKeyForSSL(ssl, certKeyPair); err != nil { log.Printf("Error loading certificate for %s: %v\n", serverName, err) - return SSLTLSExtErrAlertFatal + return ts.SSLTLSExtErrAlertFatal } } else { log.Printf("No certificate found for %s, using default\n", serverName) - return SSLTLSExtErrNoAck + return ts.SSLTLSExtErrNoAck } - return SSLTLSExtErrOK + return ts.SSLTLSExtErrOK }) } - // Load a default certificate and key - encCertPEM, err := os.ReadFile(filepath.Join(testDir, "server_enc.crt")) - if err != nil { - t.Error(err) - return nil, err - } - - signCertPEM, err := os.ReadFile(filepath.Join(testDir, "server_sign.crt")) - if err != nil { - t.Error(err) - return nil, err - } - - encCert, err := crypto.LoadCertificateFromPEM(encCertPEM) - if err != nil { - t.Error(err) - return nil, err - } - - signCert, err := crypto.LoadCertificateFromPEM(signCertPEM) - if err != nil { - t.Error(err) - return nil, err - } - - if err := ctx.UseEncryptCertificate(encCert); err != nil { - t.Error(err) - return nil, err - } - - if err := ctx.UseSignCertificate(signCert); err != nil { - t.Error(err) - return nil, err - } - - encKeyPEM, err := os.ReadFile(filepath.Join(testDir, "server_enc.key")) - if err != nil { - t.Error(err) - return nil, err - } - - signKeyPEM, err := os.ReadFile(filepath.Join(testDir, "server_sign.key")) - if err != nil { - t.Error(err) - return nil, err - } - - encKey, err := crypto.LoadPrivateKeyFromPEM(encKeyPEM) - if err != nil { - t.Error(err) - return nil, err - } - - signKey, err := crypto.LoadPrivateKeyFromPEM(signKeyPEM) + err = ctxSetGMDoubleCertKey(ctx, filepath.Join(testDir, "server_sign.crt"), + filepath.Join(testDir, "server_sign.key"), filepath.Join(testDir, "server_enc.crt"), + filepath.Join(testDir, "server_enc.key")) if err != nil { t.Error(err) return nil, err } - if err := ctx.UseEncryptPrivateKey(encKey); err != nil { - t.Error(err) - return nil, err - } - - if err := ctx.UseSignPrivateKey(signKey); err != nil { - t.Error(err) - return nil, err - } - - lis, err := Listen("tcp", "127.0.0.1:4433", ctx) + lis, err := ts.Listen("tcp", "localhost:0", ctx) if err != nil { t.Error(err) return nil, err @@ -841,71 +707,79 @@ func newNTLSServerWithSNI(t *testing.T, testDir string, certKeyPairs map[string] return &echoServer{lis}, nil } -// Load certificate and key for SSL -func loadCertAndKeyForSSL(ssl *SSL, certKeyPair crypto.GMDoubleCertKey) error { - ctx, err := NewCtx() +// Load certificate and key for SSL. +func loadCertAndKeyForSSL(ssl *ts.SSL, certKeyPair crypto.GMDoubleCertKey) error { + ctx, err := ts.NewCtx() if err != nil { - return err + return fmt.Errorf("failed to create ctx: %w", err) } encCertPEM, err := crypto.LoadPEMFromFile(certKeyPair.EncCertFile) if err != nil { log.Println(err) - return err + return fmt.Errorf("failed to load certificate from file: %w", err) } + encCert, err := crypto.LoadCertificateFromPEM(encCertPEM) if err != nil { log.Println(err) - return err + return fmt.Errorf("failed to load enc cert: %w", err) } + err = ctx.UseEncryptCertificate(encCert) if err != nil { - return err + return fmt.Errorf("failed to set enc cert: %w", err) } signCertPEM, err := crypto.LoadPEMFromFile(certKeyPair.SignCertFile) if err != nil { log.Println(err) - return err + return fmt.Errorf("failed to load sign cert from file: %w", err) } + signCert, err := crypto.LoadCertificateFromPEM(signCertPEM) if err != nil { log.Println(err) - return err + return fmt.Errorf("failed to load sign cert: %w", err) } + err = ctx.UseSignCertificate(signCert) if err != nil { - return err + return fmt.Errorf("failed to set sign cert: %w", err) } encKeyPEM, err := os.ReadFile(certKeyPair.EncKeyFile) if err != nil { log.Println(err) - return err + return fmt.Errorf("failed to read enc key file: %w", err) } + encKey, err := crypto.LoadPrivateKeyFromPEM(encKeyPEM) if err != nil { log.Println(err) - return err + return fmt.Errorf("failed to load enc key: %w", err) } + err = ctx.UseEncryptPrivateKey(encKey) if err != nil { - return err + return fmt.Errorf("failed to set enc key: %w", err) } signKeyPEM, err := os.ReadFile(certKeyPair.SignKeyFile) if err != nil { log.Println(err) - return err + return fmt.Errorf("failed to read sign key file: %w", err) } + signKey, err := crypto.LoadPrivateKeyFromPEM(signKeyPEM) if err != nil { log.Println(err) - return err + return fmt.Errorf("failed to load sign key: %w", err) } + err = ctx.UseSignPrivateKey(signKey) if err != nil { - return err + return fmt.Errorf("failed to set sign key: %w", err) } ssl.SetSSLCtx(ctx) @@ -918,7 +792,7 @@ func ReadCertificateFiles(dirPath string) (map[string]crypto.GMDoubleCertKey, er files, err := os.ReadDir(dirPath) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to read directory: %w", err) } for _, file := range files { @@ -942,11 +816,11 @@ func ReadCertificateFiles(dirPath string) (map[string]crypto.GMDoubleCertKey, er } func TestALPN(t *testing.T) { + t.Parallel() // Run server - server, err := newNTLSServerWithALPN(t, testCertDir, func(sslctx *Ctx) error { + server, err := newNTLSServerWithALPN(t, testCertDir, func(sslctx *ts.Ctx) error { return sslctx.SetCipherList("ECC-SM2-SM4-CBC-SM3") }) - if err != nil { t.Error(err) return @@ -956,12 +830,9 @@ func TestALPN(t *testing.T) { go server.RunForALPN() // Run Client - - caFile := "test/certs/sm2/chain-ca.crt" - connAddr := "127.0.0.1:4433" alpnProtocols := []string{"h3"} - ctx, err := NewCtxWithVersion(NTLS) + ctx, err := ts.NewCtxWithVersion(ts.NTLS) if err != nil { t.Error(err) return @@ -978,13 +849,14 @@ func TestALPN(t *testing.T) { return } - if err := ctx.LoadVerifyLocations(caFile, ""); err != nil { + if err := ctx.LoadVerifyLocations(testCaFile, ""); err != nil { t.Error(err) return } // Connect to the server - conn, err := DialSession("tcp", connAddr, ctx, InsecureSkipHostVerification, nil, "") + conn, err := ts.DialSession(server.Addr().Network(), server.Addr().String(), ctx, + ts.InsecureSkipHostVerification, nil, "") if err != nil { t.Log(err) return @@ -997,10 +869,10 @@ func TestALPN(t *testing.T) { // If there is an error, log it and terminate the test t.Log("Failed to get negotiated ALPN protocol:", err) return - } else { - t.Log("ALPN negotiated successfully", negotiatedProto) } + t.Log("ALPN negotiated successfully", negotiatedProto) + cipher, err := conn.CurrentCipher() if err != nil { t.Error(err) @@ -1009,8 +881,7 @@ func TestALPN(t *testing.T) { t.Log("current cipher", cipher) - request := "hello tongsuo\n" - if _, err := conn.Write([]byte(request)); err != nil { + if _, err := conn.Write([]byte(testRequest)); err != nil { t.Error(err) return } @@ -1021,14 +892,16 @@ func TestALPN(t *testing.T) { return } - if resp != request { + if resp != testRequest { t.Error("response data is not expected: ", resp) return } } -func newNTLSServerWithALPN(t *testing.T, testDir string, options ...func(sslctx *Ctx) error) (*echoServer, error) { - ctx, err := NewCtxWithVersion(NTLS) +func newNTLSServerWithALPN(t *testing.T, testDir string, options ...func(sslctx *ts.Ctx) error) (*echoServer, error) { + t.Helper() + + ctx, err := ts.NewCtxWithVersion(ts.NTLS) if err != nil { t.Error(err) return nil, err @@ -1050,75 +923,15 @@ func newNTLSServerWithALPN(t *testing.T, testDir string, options ...func(sslctx supportedProtos := []string{"h2", "http/1.1"} ctx.SetServerALPNProtos(supportedProtos) - encCertPEM, err := os.ReadFile(filepath.Join(testDir, "server_enc.crt")) - if err != nil { - t.Error(err) - return nil, err - } - - signCertPEM, err := os.ReadFile(filepath.Join(testDir, "server_sign.crt")) - if err != nil { - t.Error(err) - return nil, err - } - - encCert, err := crypto.LoadCertificateFromPEM(encCertPEM) - if err != nil { - t.Error(err) - return nil, err - } - - signCert, err := crypto.LoadCertificateFromPEM(signCertPEM) - if err != nil { - t.Error(err) - return nil, err - } - - if err := ctx.UseEncryptCertificate(encCert); err != nil { - t.Error(err) - return nil, err - } - - if err := ctx.UseSignCertificate(signCert); err != nil { - t.Error(err) - return nil, err - } - - encKeyPEM, err := os.ReadFile(filepath.Join(testDir, "server_enc.key")) - if err != nil { - t.Error(err) - return nil, err - } - - signKeyPEM, err := os.ReadFile(filepath.Join(testDir, "server_sign.key")) - if err != nil { - t.Error(err) - return nil, err - } - - encKey, err := crypto.LoadPrivateKeyFromPEM(encKeyPEM) - if err != nil { - t.Error(err) - return nil, err - } - - signKey, err := crypto.LoadPrivateKeyFromPEM(signKeyPEM) + err = ctxSetGMDoubleCertKey(ctx, filepath.Join(testDir, "server_sign.crt"), + filepath.Join(testDir, "server_sign.key"), filepath.Join(testDir, "server_enc.crt"), + filepath.Join(testDir, "server_enc.key")) if err != nil { t.Error(err) return nil, err } - if err := ctx.UseEncryptPrivateKey(encKey); err != nil { - t.Error(err) - return nil, err - } - - if err := ctx.UseSignPrivateKey(signKey); err != nil { - t.Error(err) - return nil, err - } - - lis, err := Listen("tcp", "127.0.0.1:4433", ctx) + lis, err := ts.Listen("tcp", "localhost:0", ctx) if err != nil { t.Error(err) return nil, err @@ -1127,18 +940,22 @@ func newNTLSServerWithALPN(t *testing.T, testDir string, options ...func(sslctx return &echoServer{lis}, nil } -// TestSessionReuse Test session reuse +// TestSessionReuse Test session reuse. func TestSessionReuse(t *testing.T) { + t.Parallel() // Run server // Execute for loop to test various CacheModes - for _, cacheMode := range []SessionCacheModes{ - SessionCacheOff, - SessionCacheClient, - SessionCacheServer, - SessionCacheBoth, + for _, cacheMode := range []ts.SessionCacheModes{ + ts.SessionCacheOff, + ts.SessionCacheClient, + ts.SessionCacheServer, + ts.SessionCacheBoth, } { + cacheMode := cacheMode t.Run(fmt.Sprintf("cacheMode: %d", cacheMode), func(t *testing.T) { - server, err := newNTLSServerWithSessionReuse(t, testCertDir, cacheMode, func(sslctx *Ctx) error { + t.Parallel() + + server, err := newNTLSServerWithSessionReuse(t, testCertDir, cacheMode, func(sslctx *ts.Ctx) error { return sslctx.SetCipherList("ECC-SM2-SM4-CBC-SM3") }) if err != nil { @@ -1150,32 +967,32 @@ func TestSessionReuse(t *testing.T) { go server.Run() // Run client - caFile := "test/certs/sm2/chain-ca.crt" - connAddr := "127.0.0.1:4433" - - ctx, err := NewCtxWithVersion(NTLS) + ctx, err := ts.NewCtxWithVersion(ts.NTLS) if err != nil { t.Error(err) return } - ctx.SetOptions(NoTicket) + ctx.SetOptions(ts.NoTicket) + if err := ctx.SetCipherList("ECC-SM2-SM4-CBC-SM3"); err != nil { t.Error(err) return } - if err := ctx.LoadVerifyLocations(caFile, ""); err != nil { + if err := ctx.LoadVerifyLocations(testCaFile, ""); err != nil { t.Error(err) return } // Connect to the server, and get reused session, use session to connect again // Use a for loop to connect 2 times - var sessions = make([][]byte, 2) + sessions := make([][]byte, 2) + var session []byte for i := 0; i < 2; i++ { - conn, err := DialSession("tcp", connAddr, ctx, InsecureSkipHostVerification, session, "") + conn, err := ts.DialSession(server.Addr().Network(), server.Addr().String(), ctx, + ts.InsecureSkipHostVerification, session, "") if err != nil { t.Log(err) return @@ -1187,10 +1004,10 @@ func TestSessionReuse(t *testing.T) { t.Error(err) return } + session = sessions[i] - request := "hello tongsuo\n" - if _, err := conn.Write([]byte(request)); err != nil { + if _, err := conn.Write([]byte(testRequest)); err != nil { t.Error(err) return } @@ -1201,7 +1018,7 @@ func TestSessionReuse(t *testing.T) { return } - if resp != request { + if resp != testRequest { t.Error("response data is not expected: ", resp) return } @@ -1210,25 +1027,31 @@ func TestSessionReuse(t *testing.T) { } switch cacheMode { - case SessionCacheOff, SessionCacheClient: + case ts.SessionCacheOff, ts.SessionCacheClient: if !bytes.Equal(sessions[0], sessions[1]) { t.Log("session is not reused") } else { t.Error("session is reused") } - case SessionCacheServer, SessionCacheBoth: + case ts.SessionCacheServer, ts.SessionCacheBoth: if !bytes.Equal(sessions[0], sessions[1]) { t.Error("session is not reused") } else { t.Log("session is reused") } + default: + t.Error("unexpected cache mode") } }) } } -func newNTLSServerWithSessionReuse(t *testing.T, testDir string, cacheMode SessionCacheModes, options ...func(sslctx *Ctx) error) (*echoServer, error) { - ctx, err := NewCtxWithVersion(NTLS) +func newNTLSServerWithSessionReuse(t *testing.T, testDir string, cacheMode ts.SessionCacheModes, + options ...func(sslctx *ts.Ctx) error, +) (*echoServer, error) { + t.Helper() + + ctx, err := ts.NewCtxWithVersion(ts.NTLS) if err != nil { t.Error(err) return nil, err @@ -1246,79 +1069,19 @@ func newNTLSServerWithSessionReuse(t *testing.T, testDir string, cacheMode Sessi return nil, err } - encCertPEM, err := os.ReadFile(filepath.Join(testDir, "server_enc.crt")) - if err != nil { - t.Error(err) - return nil, err - } - - signCertPEM, err := os.ReadFile(filepath.Join(testDir, "server_sign.crt")) - if err != nil { - t.Error(err) - return nil, err - } - - encCert, err := crypto.LoadCertificateFromPEM(encCertPEM) - if err != nil { - t.Error(err) - return nil, err - } - - signCert, err := crypto.LoadCertificateFromPEM(signCertPEM) - if err != nil { - t.Error(err) - return nil, err - } - - if err := ctx.UseEncryptCertificate(encCert); err != nil { - t.Error(err) - return nil, err - } - - if err := ctx.UseSignCertificate(signCert); err != nil { - t.Error(err) - return nil, err - } - - encKeyPEM, err := os.ReadFile(filepath.Join(testDir, "server_enc.key")) - if err != nil { - t.Error(err) - return nil, err - } - - signKeyPEM, err := os.ReadFile(filepath.Join(testDir, "server_sign.key")) - if err != nil { - t.Error(err) - return nil, err - } - - encKey, err := crypto.LoadPrivateKeyFromPEM(encKeyPEM) - if err != nil { - t.Error(err) - return nil, err - } - - signKey, err := crypto.LoadPrivateKeyFromPEM(signKeyPEM) + err = ctxSetGMDoubleCertKey(ctx, filepath.Join(testDir, "server_sign.crt"), + filepath.Join(testDir, "server_sign.key"), filepath.Join(testDir, "server_enc.crt"), + filepath.Join(testDir, "server_enc.key")) if err != nil { t.Error(err) return nil, err } - if err := ctx.UseEncryptPrivateKey(encKey); err != nil { - t.Error(err) - return nil, err - } - - if err := ctx.UseSignPrivateKey(signKey); err != nil { - t.Error(err) - return nil, err - } - // Set session reuse sessionCacheMode := ctx.SetSessionCacheMode(cacheMode) t.Log("session cache mode", sessionCacheMode, "new mode", cacheMode) - lis, err := Listen("tcp", "127.0.0.1:4433", ctx) + lis, err := ts.Listen("tcp", "localhost:0", ctx) if err != nil { t.Error(err) return nil, err @@ -1328,6 +1091,8 @@ func newNTLSServerWithSessionReuse(t *testing.T, testDir string, cacheMode Sessi } func TestTLS13Connection(t *testing.T) { + t.Parallel() + // Run server server, err := newTLS13Server(t, "test/certs") if err != nil { @@ -1339,15 +1104,13 @@ func TestTLS13Connection(t *testing.T) { go server.Run() // Run client - connAddr := "127.0.0.1:4433" - - ctx, err := NewCtxWithVersion(TLSv1_3) + ctx, err := ts.NewCtxWithVersion(ts.TLSv1_3) if err != nil { t.Error(err) return } - conn, err := Dial("tcp", connAddr, ctx, InsecureSkipHostVerification, "") + conn, err := ts.Dial(server.Addr().Network(), server.Addr().String(), ctx, ts.InsecureSkipHostVerification, "") if err != nil { t.Log(err) return @@ -1369,8 +1132,7 @@ func TestTLS13Connection(t *testing.T) { t.Log("tls version", tlsVersion) - request := "hello tongsuo\n" - if _, err := conn.Write([]byte(request)); err != nil { + if _, err := conn.Write([]byte(testRequest)); err != nil { t.Error(err) return } @@ -1381,14 +1143,16 @@ func TestTLS13Connection(t *testing.T) { return } - if resp != request { + if resp != testRequest { t.Error("response data is not expected: ", resp) return } } -func newTLS13Server(t *testing.T, testDir string, options ...func(sslctx *Ctx) error) (*echoServer, error) { - ctx, err := NewCtxWithVersion(TLSv1_3) +func newTLS13Server(t *testing.T, testDir string, options ...func(sslctx *ts.Ctx) error) (*echoServer, error) { + t.Helper() + + ctx, err := ts.NewCtxWithVersion(ts.TLSv1_3) if err != nil { t.Error(err) return nil, err @@ -1435,7 +1199,7 @@ func newTLS13Server(t *testing.T, testDir string, options ...func(sslctx *Ctx) e return nil, err } - lis, err := Listen("tcp", "127.0.0.1:4433", ctx) + lis, err := ts.Listen("tcp", "localhost:0", ctx) if err != nil { t.Error(err) return nil, err @@ -1445,6 +1209,8 @@ func newTLS13Server(t *testing.T, testDir string, options ...func(sslctx *Ctx) e } func TestTLSv13SMCipher(t *testing.T) { + t.Parallel() + ciphers := []string{ TLSSMGCMCipher, TLSSMCCMCipher, @@ -1452,9 +1218,11 @@ func TestTLSv13SMCipher(t *testing.T) { testCertDir := "test/certs" for _, cipher := range ciphers { + cipher := cipher t.Run(cipher, func(t *testing.T) { + t.Parallel() // Run server - server, err := newTLSv13SMCipherServer(t, testCertDir, func(sslctx *Ctx) error { + server, err := newTLS13Server(t, testCertDir, func(sslctx *ts.Ctx) error { return sslctx.SetCipherSuites(cipher) }) if err != nil { @@ -1466,7 +1234,7 @@ func TestTLSv13SMCipher(t *testing.T) { go server.Run() // Run client - ctx, err := NewCtxWithVersion(TLSv1_3) + ctx, err := ts.NewCtxWithVersion(ts.TLSv1_3) if err != nil { t.Error(err) return @@ -1477,7 +1245,8 @@ func TestTLSv13SMCipher(t *testing.T) { return } - conn, err := Dial("tcp", "127.0.0.1:4433", ctx, InsecureSkipHostVerification, "") + conn, err := ts.Dial(server.Addr().Network(), server.Addr().String(), ctx, + ts.InsecureSkipHostVerification, "") if err != nil { t.Error(err) return @@ -1492,8 +1261,7 @@ func TestTLSv13SMCipher(t *testing.T) { t.Log("current cipher", cipher) - request := "hello tongsuo\n" - if _, err := conn.Write([]byte(request)); err != nil { + if _, err := conn.Write([]byte(testRequest)); err != nil { t.Error(err) return } @@ -1504,67 +1272,10 @@ func TestTLSv13SMCipher(t *testing.T) { return } - if resp != request { + if resp != testRequest { t.Error("response data is not expected: ", resp) return } }) } } - -func newTLSv13SMCipherServer(t *testing.T, testDir string, options ...func(sslctx *Ctx) error) (*echoServer, error) { - ctx, err := NewCtxWithVersion(TLSv1_3) - if err != nil { - t.Error(err) - return nil, err - } - - for _, f := range options { - if err := f(ctx); err != nil { - t.Error(err) - return nil, err - } - } - - certPEM, err := os.ReadFile(filepath.Join(testDir, "sm2-cert.pem")) - if err != nil { - t.Error(err) - return nil, err - } - - cert, err := crypto.LoadCertificateFromPEM(certPEM) - if err != nil { - t.Error(err) - return nil, err - } - - if err := ctx.UseCertificate(cert); err != nil { - t.Error(err) - return nil, err - } - - keyPEM, err := os.ReadFile(filepath.Join(testDir, "sm2.key")) - if err != nil { - t.Error(err) - return nil, err - } - - key, err := crypto.LoadPrivateKeyFromPEM(keyPEM) - if err != nil { - t.Error(err) - return nil, err - } - - if err := ctx.UsePrivateKey(key); err != nil { - t.Error(err) - return nil, err - } - - lis, err := Listen("tcp", "127.0.0.1:4433", ctx) - if err != nil { - t.Error(err) - return nil, err - } - - return &echoServer{lis}, nil -} diff --git a/pem.go b/pem.go index 37c06ad..f6d770a 100644 --- a/pem.go +++ b/pem.go @@ -18,15 +18,15 @@ import ( "regexp" ) -var pemSplit *regexp.Regexp = regexp.MustCompile(`(?sm)` + +var pemSplit = regexp.MustCompile(`(?sm)` + `(^-----[\s-]*?BEGIN.*?-----$` + `.*?` + `^-----[\s-]*?END.*?-----$)`) func SplitPEM(data []byte) [][]byte { var results [][]byte - for _, block := range pemSplit.FindAll(data, -1) { - results = append(results, block) - } + + results = append(results, pemSplit.FindAll(data, -1)...) + return results } diff --git a/shim.c b/shim.c index b005996..7cba859 100644 --- a/shim.c +++ b/shim.c @@ -29,10 +29,9 @@ #include "_cgo_export.h" -int X_tongsuogo_init(void) { +void X_tongsuogo_init(void) { SSL_load_error_strings(); SSL_library_init(); - return 0; } long X_SSL_set_options(SSL* ssl, long options) { @@ -218,7 +217,7 @@ int X_SSL_CTX_ticket_key_cb(SSL *s, unsigned char key_name[16], SSL_CTX* ssl_ctx = SSL_get_SSL_CTX(s); void* p = SSL_CTX_get_ex_data(ssl_ctx, get_ssl_ctx_idx()); // get the pointer to the go Ctx object and pass it back into the thunk - return go_ticket_key_cb_thunk(p, s, key_name, iv, cctx, hctx, enc); + return go_ticket_key_cb_thunk(p, key_name, cctx, hctx, enc); } int X_X509_add_ref(X509* x509) { diff --git a/shim.h b/shim.h index 7472f58..cd9b98d 100644 --- a/shim.h +++ b/shim.h @@ -31,7 +31,7 @@ #endif /* shim methods */ -extern int X_tongsuogo_init(void); +extern void X_tongsuogo_init(void); /* SSL methods */ extern long X_SSL_set_options(SSL* ssl, long options); diff --git a/sni.c b/sni.c index dded092..38101bc 100644 --- a/sni.c +++ b/sni.c @@ -17,9 +17,9 @@ #include int sni_cb(SSL *con, int *ad, void *arg) { - SSL_CTX* ssl_ctx = ssl_ctx = SSL_get_SSL_CTX(con); + SSL_CTX* ssl_ctx = SSL_get_SSL_CTX(con); void* p = SSL_CTX_get_ex_data(ssl_ctx, get_ssl_ctx_idx()); - return sni_cb_thunk(p, con, ad, arg); + return sniCbThunk(p, con, ad, arg); } int alpn_cb(SSL *ssl_conn, const unsigned char **out, unsigned char *outlen, const unsigned char *in, unsigned int inlen, void *arg) { diff --git a/ssl.go b/ssl.go index b5742f9..ac5f1f4 100644 --- a/ssl.go +++ b/ssl.go @@ -32,37 +32,35 @@ const ( ) const ( - OPENSSL_NPN_NEGOTIATED C.int = C.OPENSSL_NPN_NEGOTIATED - OPENSSL_NPN_NO_OVERLAP C.int = C.OPENSSL_NPN_NO_OVERLAP + NPNNegotiated C.int = C.OPENSSL_NPN_NEGOTIATED + NPNNoOverlap C.int = C.OPENSSL_NPN_NO_OVERLAP ) -var ( - ssl_idx = C.X_SSL_new_index() -) +var sslIdx = C.X_SSL_new_index() //export get_ssl_idx func get_ssl_idx() C.int { - return ssl_idx + return sslIdx } type SSL struct { - ssl *C.SSL - verify_cb VerifyCallback + ssl *C.SSL + verifyCb VerifyCallback } //export go_ssl_verify_cb_thunk -func go_ssl_verify_cb_thunk(p unsafe.Pointer, ok C.int, ctx *C.X509_STORE_CTX) C.int { +func go_ssl_verify_cb_thunk(callback unsafe.Pointer, ok C.int, ctx *C.X509_STORE_CTX) C.int { defer func() { if err := recover(); err != nil { - //logger.Critf("openssl: verify callback panic'd: %v", err) + // logger.Critf("openssl: verify callback panic'd: %v", err) os.Exit(1) } }() - verify_cb := (*SSL)(p).verify_cb + verifyCb := (*SSL)(callback).verifyCb // set up defaults just in case verify_cb is nil - if verify_cb != nil { - store := &CertificateStoreCtx{ctx: ctx} - if verify_cb(ok == 1, store) { + if verifyCb != nil { + store := &CertificateStoreCtx{ctx: ctx, sslCtx: nil} + if verifyCb(ok == 1, store) { ok = 1 } else { ok = 0 @@ -97,9 +95,9 @@ func (s *SSL) ClearOptions(options Options) Options { // SetVerify controls peer verification settings. See // http://www.openssl.org/docs/ssl/SSL_CTX_set_verify.html -func (s *SSL) SetVerify(options VerifyOptions, verify_cb VerifyCallback) { - s.verify_cb = verify_cb - if verify_cb != nil { +func (s *SSL) SetVerify(options VerifyOptions, verifyCb VerifyCallback) { + s.verifyCb = verifyCb + if verifyCb != nil { C.SSL_set_verify(s.ssl, C.int(options), (*[0]byte)(C.X_SSL_verify_cb)) } else { C.SSL_set_verify(s.ssl, C.int(options), nil) @@ -109,19 +107,19 @@ func (s *SSL) SetVerify(options VerifyOptions, verify_cb VerifyCallback) { // SetVerifyMode controls peer verification setting. See // http://www.openssl.org/docs/ssl/SSL_CTX_set_verify.html func (s *SSL) SetVerifyMode(options VerifyOptions) { - s.SetVerify(options, s.verify_cb) + s.SetVerify(options, s.verifyCb) } // SetVerifyCallback controls peer verification setting. See // http://www.openssl.org/docs/ssl/SSL_CTX_set_verify.html -func (s *SSL) SetVerifyCallback(verify_cb VerifyCallback) { - s.SetVerify(s.VerifyMode(), verify_cb) +func (s *SSL) SetVerifyCallback(verifyCb VerifyCallback) { + s.SetVerify(s.VerifyMode(), verifyCb) } // GetVerifyCallback returns callback function. See // http://www.openssl.org/docs/ssl/SSL_CTX_set_verify.html func (s *SSL) GetVerifyCallback() VerifyCallback { - return s.verify_cb + return s.verifyCb } // VerifyMode returns peer verification setting. See @@ -155,40 +153,44 @@ func (s *SSL) SetSSLCtx(ctx *Ctx) { C.SSL_set_SSL_CTX(s.ssl, ctx.ctx) } -//export sni_cb_thunk -func sni_cb_thunk(p unsafe.Pointer, con *C.SSL, ad unsafe.Pointer, arg unsafe.Pointer) C.int { +//export sniCbThunk +func sniCbThunk(callback unsafe.Pointer, con *C.SSL, ad unsafe.Pointer, arg unsafe.Pointer) C.int { + _, _ = ad, arg // unused + defer func() { if err := recover(); err != nil { - //logger.Critf("openssl: verify callback sni panic'd: %v", err) + // logger.Critf("openssl: verify callback sni panic'd: %v", err) os.Exit(1) } }() - sni_cb := (*Ctx)(p).sni_cb + sniCb := (*Ctx)(callback).sniCb - s := &SSL{ssl: con} + s := &SSL{ssl: con, verifyCb: nil} // This attaches a pointer to our SSL struct into the SNI callback. C.SSL_set_ex_data(s.ssl, get_ssl_idx(), unsafe.Pointer(s.ssl)) // Note: this is ctx.sni_cb, not C.sni_cb - return C.int(sni_cb(s)) + return C.int(sniCb(s)) } //export alpn_cb_thunk -func alpn_cb_thunk(p unsafe.Pointer, con *C.SSL, out unsafe.Pointer, outlen unsafe.Pointer, in unsafe.Pointer, inlen uint, arg unsafe.Pointer) C.int { +func alpn_cb_thunk(callback unsafe.Pointer, con *C.SSL, out unsafe.Pointer, outlen unsafe.Pointer, in unsafe.Pointer, + inlen uint, arg unsafe.Pointer, +) C.int { defer func() { if err := recover(); err != nil { - //logger.Critf("openssl: verify callback alpn panic'd: %v", err) + // logger.Critf("openssl: verify callback alpn panic'd: %v", err) os.Exit(1) } }() - alpn_cb := (*Ctx)(p).alpn_cb + alpnCb := (*Ctx)(callback).alpnCb - s := &SSL{ssl: con} + s := &SSL{ssl: con, verifyCb: nil} // This attaches a pointer to our SSL struct into the ALPN callback. C.SSL_set_ex_data(s.ssl, get_ssl_idx(), unsafe.Pointer(s.ssl)) // Ensure the out parameter is treated as a pointer to const unsigned char - return C.int(alpn_cb(s, out, outlen, in, inlen, arg)) + return C.int(alpnCb(s, out, outlen, in, inlen, arg)) } diff --git a/ssl_test.go b/ssl_test.go index fcc8d8f..c13d83d 100644 --- a/ssl_test.go +++ b/ssl_test.go @@ -12,25 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. -package tongsuogo +package tongsuogo_test import ( "bytes" "crypto/rand" "crypto/tls" "io" - "io/ioutil" "net" "sync" "testing" "time" + ts "github.com/tongsuo-project/tongsuo-go-sdk" "github.com/tongsuo-project/tongsuo-go-sdk/crypto" "github.com/tongsuo-project/tongsuo-go-sdk/utils" ) -var ( - certBytes = []byte(`-----BEGIN CERTIFICATE----- +const ( + certBytes = `-----BEGIN CERTIFICATE----- MIIExjCCAy6gAwIBAgIRAMqZUO0eR6sVZ3A8iG8bJK8wDQYJKoZIhvcNAQELBQAw ezEeMBwGA1UEChMVbWtjZXJ0IGRldmVsb3BtZW50IENBMSgwJgYDVQQLDB90b21z YXd5ZXJAQi1GRzc5TUw3SC0wNDQ4LmxvY2FsMS8wLQYDVQQDDCZta2NlcnQgdG9t @@ -58,8 +58,8 @@ UjxAjLXvWmij6ilpMADnLQA0SH6s+9E2Aa5LTpEMDqXORcu+sq5/m3RuDtVxuYdU HNnVAmIdTLKC9CWnRfDxH8zPgIr/L8Yhdw92YST8hNqGQHeR0qoBcKYMHkpH6Ay4 yuKERO5LaAmjoXJW3n5Zal6jogf3wpiV1o4= -----END CERTIFICATE----- -`) - keyBytes = []byte(`-----BEGIN RSA PRIVATE KEY----- +` + keyBytes = `-----BEGIN RSA PRIVATE KEY----- MIIG5AIBAAKCAYEAwoj05U9j+Y8xGs8xV9n0+6U6KJSS6eMn9xL4CW9qoxbpmNXw ZNqujLWH0LrceF34X6vqG3U6mrhQKt2g4ywWGOmxQk9OuJqwHrVFPZyheiS4zxwU D/CnsZ5cq1aOKH/PamWnaafNylBdj33o3CQXZBRkiYNandJGPjQsVjfI6I08y2RO @@ -98,83 +98,65 @@ JBfFkFXDvcbaYmpcOVCS1susPrPr8rgIm+vK6X+UoWE1RcCMGQMKObQIDGpE4IWe TDoukqQ8peoffk6mtiCnph9Cl2uqAgmmX+GyunEMIdF/ySG0CCcfz180GsQCucax +AxW2R7NJMAHvfeaYoLtSMYEVTS8sSpuIbRTfGuxbmMOD8a03gU6AA== -----END RSA PRIVATE KEY----- -`) - prime256v1KeyBytes = []byte(`-----BEGIN EC PRIVATE KEY----- -MHcCAQEEIB/XL0zZSsAu+IQF1AI/nRneabb2S126WFlvvhzmYr1KoAoGCCqGSM49 -AwEHoUQDQgAESSFGWwF6W1hoatKGPPorh4+ipyk0FqpiWdiH+4jIiU39qtOeZGSh -1QgSbzfdHxvoYI0FXM+mqE7wec0kIvrrHw== ------END EC PRIVATE KEY----- -`) - prime256v1CertBytes = []byte(`-----BEGIN CERTIFICATE----- -MIIChTCCAiqgAwIBAgIJAOQII2LQl4uxMAoGCCqGSM49BAMCMIGcMQswCQYDVQQG -EwJVUzEPMA0GA1UECAwGS2Fuc2FzMRAwDgYDVQQHDAdOb3doZXJlMR8wHQYDVQQK -DBZGYWtlIENlcnRpZmljYXRlcywgSW5jMUkwRwYDVQQDDEBhMWJkZDVmZjg5ZjQy -N2IwZmNiOTdlNDMyZTY5Nzg2NjI2ODJhMWUyNzM4MDhkODE0ZWJiZjY4ODBlYzA3 -NDljMB4XDTE3MTIxNTIwNDU1MVoXDTI3MTIxMzIwNDU1MVowgZwxCzAJBgNVBAYT -AlVTMQ8wDQYDVQQIDAZLYW5zYXMxEDAOBgNVBAcMB05vd2hlcmUxHzAdBgNVBAoM -FkZha2UgQ2VydGlmaWNhdGVzLCBJbmMxSTBHBgNVBAMMQGExYmRkNWZmODlmNDI3 -YjBmY2I5N2U0MzJlNjk3ODY2MjY4MmExZTI3MzgwOGQ4MTRlYmJmNjg4MGVjMDc0 -OWMwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAARJIUZbAXpbWGhq0oY8+iuHj6Kn -KTQWqmJZ2If7iMiJTf2q055kZKHVCBJvN90fG+hgjQVcz6aoTvB5zSQi+usfo1Mw -UTAdBgNVHQ4EFgQUfRYAFhlGM1wzvusyGrm26Vrbqm4wHwYDVR0jBBgwFoAUfRYA -FhlGM1wzvusyGrm26Vrbqm4wDwYDVR0TAQH/BAUwAwEB/zAKBggqhkjOPQQDAgNJ -ADBGAiEA6PWNjm4B6zs3Wcha9qyDdfo1ILhHfk9rZEAGrnfyc2UCIQD1IDVJUkI4 -J/QVoOtP5DOdRPs/3XFy0Bk0qH+Uj5D7LQ== ------END CERTIFICATE----- -`) - ed25519CertBytes = []byte(`-----BEGIN CERTIFICATE----- -MIIBIzCB1gIUd0UUPX+qHrSKSVN9V/A3F1Eeti4wBQYDK2VwMDYxCzAJBgNVBAYT -AnVzMQ0wCwYDVQQKDARDU0NPMRgwFgYDVQQDDA9lZDI1NTE5X3Jvb3RfY2EwHhcN -MTgwODE3MDMzNzQ4WhcNMjgwODE0MDMzNzQ4WjAzMQswCQYDVQQGEwJ1czENMAsG -A1UECgwEQ1NDTzEVMBMGA1UEAwwMZWQyNTUxOV9sZWFmMCowBQYDK2VwAyEAKZZJ -zzlBcpjdbvzV0BRoaSiJKxbU6GnFeAELA0cHWR0wBQYDK2VwA0EAbfUJ7L7v3GDq -Gv7R90wQ/OKAc+o0q9eOrD6KRYDBhvlnMKqTMRVucnHXfrd5Rhmf4yHTvFTOhwmO -t/hpmISAAA== ------END CERTIFICATE----- -`) - ed25519KeyBytes = []byte(`-----BEGIN PRIVATE KEY----- -MC4CAQAwBQYDK2VwBCIEIL3QVwyuusKuLgZwZn356UHk9u1REGHbNTLtFMPKNQSb ------END PRIVATE KEY----- -`) +` ) // NetPipe creates a TCP connection pipe and returns two connections. -func NetPipe(t testing.TB) (net.Conn, net.Conn) { - l, err := net.Listen("tcp", "localhost:0") +func NetPipe(tb testing.TB) (net.Conn, net.Conn) { + tb.Helper() + + lis, err := net.Listen("tcp", "localhost:0") if err != nil { - t.Fatal(err) + tb.Fatal(err) } - defer l.Close() + defer lis.Close() // Use Future pattern to create client connection asynchronously clientFuture := utils.NewFuture() go func() { - clientFuture.Set(net.Dial(l.Addr().Network(), l.Addr().String())) + clientFuture.Set(net.Dial(lis.Addr().Network(), lis.Addr().String())) }() - var errs utils.ErrorGroup - serverConn, err := l.Accept() + var ( + errs utils.ErrorGroup + conn net.Conn + ok bool + ) + + serverConn, err := lis.Accept() errs.Add(err) clientConn, err := clientFuture.Get() errs.Add(err) + if clientConn != nil { + conn, ok = clientConn.(net.Conn) + if !ok { + tb.Fatal("clientConn is not a net.Conn") + } + } + err = errs.Finalize() - if err != nil { - if serverConn != nil { - err := serverConn.Close() - if err != nil { - t.Fatal(err) - } + if err == nil { + return serverConn, conn + } + + if serverConn != nil { + err := serverConn.Close() + if err != nil { + tb.Fatal(err) } - if clientConn != nil { - err := clientConn.(net.Conn).Close() - if err != nil { - t.Fatal(err) - } + } + + if clientConn != nil { + err := conn.Close() + if err != nil { + tb.Fatal(err) } - t.Fatal(err) } - return serverConn, clientConn.(net.Conn) + + tb.Fatal(err) + + return nil, nil } // HandshakingConn interface extends net.Conn interface with Handshake method. @@ -184,20 +166,23 @@ type HandshakingConn interface { } // SimpleConnTest tests simple SSL/TLS connections. -func SimpleConnTest(t testing.TB, constructor func( - t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { +func SimpleConnTest(tb testing.TB, constructor func( + t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn), +) { + tb.Helper() // Create network pipe - serverConn, clientConn := NetPipe(t) + serverConn, clientConn := NetPipe(tb) defer serverConn.Close() defer clientConn.Close() data := "first test string\n" // Create SSL/TLS connections using provided constructor - server, client := constructor(t, serverConn, clientConn) + server, client := constructor(tb, serverConn, clientConn) defer closeBoth(server, client) var wg sync.WaitGroup + wg.Add(2) go func() { @@ -205,17 +190,17 @@ func SimpleConnTest(t testing.TB, constructor func( err := client.Handshake() if err != nil { - t.Fatal(err) + tb.Fatal(err) } _, err = io.Copy(client, bytes.NewReader([]byte(data))) if err != nil { - t.Fatal(err) + tb.Fatal(err) } err = client.Close() if err != nil { - t.Fatal(err) + tb.Fatal(err) } }() @@ -225,18 +210,19 @@ func SimpleConnTest(t testing.TB, constructor func( err := server.Handshake() if err != nil { - t.Fatal(err) + tb.Fatal(err) } buf := bytes.NewBuffer(make([]byte, 0, len(data))) + _, err = io.CopyN(buf, server, int64(len(data))) if err != nil { - t.Fatal(err) - } - if string(buf.Bytes()) != data { - t.Fatal("mismatched data") + tb.Fatal(err) } + if buf.String() != data { + tb.Fatal("mismatched data") + } }() wg.Wait() } @@ -244,7 +230,9 @@ func SimpleConnTest(t testing.TB, constructor func( // closeBoth closes two connections. func closeBoth(closer1, closer2 io.Closer) { var wg sync.WaitGroup + wg.Add(2) + go func() { defer wg.Done() closer1.Close() @@ -257,15 +245,18 @@ func closeBoth(closer1, closer2 io.Closer) { } // ClosingTest tests connection closing scenarios. -func ClosingTest(t testing.TB, constructor func( - t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { +func ClosingTest(tb testing.TB, constructor func( + t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn), +) { + tb.Helper() runTest := func(serverWrites bool) { // Create network pipe - serverConn, clientConn := NetPipe(t) + serverConn, clientConn := NetPipe(tb) defer serverConn.Close() defer clientConn.Close() - server, client := constructor(t, serverConn, clientConn) + + server, client := constructor(tb, serverConn, clientConn) defer closeBoth(server, client) // Determine who writes and who reads based on server_writes parameter @@ -279,13 +270,15 @@ func ClosingTest(t testing.TB, constructor func( } var wg sync.WaitGroup + wg.Add(2) go func() { defer wg.Done() + _, err := sslconn1.Write([]byte("hello")) if err != nil { - t.Fatal(err) + tb.Fatal(err) } sslconn1.Close() @@ -293,12 +286,14 @@ func ClosingTest(t testing.TB, constructor func( go func() { defer wg.Done() + data, err := io.ReadAll(sslconn2) if err != nil { - t.Fatal(err) + tb.Fatal(err) } + if !bytes.Equal(data, []byte("hello")) { - t.Fatal("bytes don't match") + tb.Fatal("bytes don't match") } }() @@ -312,7 +307,9 @@ func ClosingTest(t testing.TB, constructor func( // ThroughputBenchmark benchmarks SSL/TLS connection throughput. func ThroughputBenchmark(b *testing.B, constructor func( - t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { + t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn), +) { + b.Helper() // Create network pipe serverConn, clientConn := NetPipe(b) defer serverConn.Close() @@ -325,18 +322,22 @@ func ThroughputBenchmark(b *testing.B, constructor func( // Set benchmark parameters b.SetBytes(1024) data := make([]byte, b.N*1024) - _, err := io.ReadFull(rand.Reader, data[:]) + + _, err := io.ReadFull(rand.Reader, data) if err != nil { b.Fatal(err) } b.ResetTimer() + var wg sync.WaitGroup + wg.Add(2) go func() { defer wg.Done() - _, err = io.Copy(client, bytes.NewReader([]byte(data))) + + _, err = io.Copy(client, bytes.NewReader(data)) if err != nil { b.Error(err) } @@ -346,10 +347,12 @@ func ThroughputBenchmark(b *testing.B, constructor func( defer wg.Done() buf := &bytes.Buffer{} + _, err = io.CopyN(buf, server, int64(len(data))) if err != nil { b.Error(err) } + if !bytes.Equal(buf.Bytes(), data) { b.Error("mismatched data") } @@ -359,179 +362,219 @@ func ThroughputBenchmark(b *testing.B, constructor func( } // StdlibConstructor creates standard library SSL/TLS connections. -func StdlibConstructor(t testing.TB, serverConn, clientConn net.Conn) ( - server, client HandshakingConn) { - cert, err := tls.X509KeyPair(certBytes, keyBytes) +func StdlibConstructor(tb testing.TB, serverConn, clientConn net.Conn) (HandshakingConn, HandshakingConn) { + tb.Helper() + + cert, err := tls.X509KeyPair([]byte(certBytes), []byte(keyBytes)) if err != nil { - t.Fatal(err) + tb.Fatal(err) } - config := &tls.Config{ - Certificates: []tls.Certificate{cert}, - InsecureSkipVerify: true, - CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true, + CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, } - server = tls.Server(serverConn, config) - client = tls.Client(clientConn, config) + + server := tls.Server(serverConn, tlsConfig) + client := tls.Client(clientConn, tlsConfig) + return server, client } // StdlibTLSv13Constructor creates standard library SSL/TLS connections with TLSv1.3. -func StdlibTLSv13Constructor(t testing.TB, serverConn, clientConn net.Conn) ( - server, client HandshakingConn) { - cert, err := tls.X509KeyPair(certBytes, keyBytes) +func StdlibTLSv13Constructor(tb testing.TB, serverConn, clientConn net.Conn) (HandshakingConn, HandshakingConn) { + tb.Helper() + + cert, err := tls.X509KeyPair([]byte(certBytes), []byte(keyBytes)) if err != nil { - t.Fatal(err) - } - config := &tls.Config{ - Certificates: []tls.Certificate{cert}, - InsecureSkipVerify: true, - MinVersion: tls.VersionTLS13, - MaxVersion: tls.VersionTLS13, + tb.Fatal(err) } - server = tls.Server(serverConn, config) - client = tls.Client(clientConn, config) + + tlsConfig := &tls.Config{} + tlsConfig.Certificates = []tls.Certificate{cert} + tlsConfig.InsecureSkipVerify = true + tlsConfig.MinVersion = tls.VersionTLS13 + tlsConfig.MaxVersion = tls.VersionTLS13 + + server := tls.Server(serverConn, tlsConfig) + client := tls.Client(clientConn, tlsConfig) + return server, client } // passThruVerify is used to pass through certificate verification. -func passThruVerify(t testing.TB) func(bool, *CertificateStoreCtx) bool { - x := func(ok bool, store *CertificateStoreCtx) bool { +func passThruVerify(tb testing.TB) func(bool, *ts.CertificateStoreCtx) bool { + tb.Helper() + + check := func(ok bool, store *ts.CertificateStoreCtx) bool { cert := store.GetCurrentCert() if cert == nil { - t.Fatalf("Could not obtain cert from store\n") + tb.Fatalf("Could not obtain cert from store\n") } + sn := cert.GetSerialNumberHex() if len(sn) == 0 { - t.Fatalf("Could not obtain serial number from cert") + tb.Fatalf("Could not obtain serial number from cert") } + return ok } - return x + + return check } // OpenSSLConstructor creates OpenSSL SSL/TLS connections. -func OpenSSLConstructor(t testing.TB, serverConn, clientConn net.Conn) ( - server, client HandshakingConn) { - ctx, err := NewCtx() +func OpenSSLConstructor(tb testing.TB, serverConn, clientConn net.Conn) (HandshakingConn, HandshakingConn) { + tb.Helper() + + ctx, err := ts.NewCtx() if err != nil { - t.Fatal(err) + tb.Fatal(err) } - ctx.SetVerify(VerifyNone, passThruVerify(t)) - key, err := crypto.LoadPrivateKeyFromPEM(keyBytes) + + ctx.SetVerify(ts.VerifyNone, passThruVerify(tb)) + + key, err := crypto.LoadPrivateKeyFromPEM([]byte(keyBytes)) if err != nil { - t.Fatal(err) + tb.Fatal(err) } + err = ctx.UsePrivateKey(key) if err != nil { - t.Fatal(err) + tb.Fatal(err) } - cert, err := crypto.LoadCertificateFromPEM(certBytes) + + cert, err := crypto.LoadCertificateFromPEM([]byte(certBytes)) if err != nil { - t.Fatal(err) + tb.Fatal(err) } + err = ctx.UseCertificate(cert) if err != nil { - t.Fatal(err) + tb.Fatal(err) } + err = ctx.SetCipherList("AES128-SHA") if err != nil { - t.Fatal(err) + tb.Fatal(err) } - server, err = Server(serverConn, ctx) + + server, err := ts.Server(serverConn, ctx) if err != nil { - t.Fatal(err) + tb.Fatal(err) } - client, err = Client(clientConn, ctx) + + client, err := ts.Client(clientConn, ctx) if err != nil { - t.Fatal(err) + tb.Fatal(err) } + return server, client } // OpenSSLTLSv3Constructor function is used to create SSL/TLS connections for OpenSSL and TLSv3. -func OpenSSLTLSv3Constructor(t testing.TB, serverConn, clientConn net.Conn) ( - server, client HandshakingConn) { - ctx, err := NewCtxWithVersion(SSLv3) +func OpenSSLTLSv3Constructor(tb testing.TB, serverConn, clientConn net.Conn) (HandshakingConn, HandshakingConn) { + tb.Helper() + + ctx, err := ts.NewCtxWithVersion(ts.SSLv3) if err != nil { - t.Fatal(err) + tb.Fatal(err) } - ctx.SetVerify(VerifyNone, passThruVerify(t)) - key, err := crypto.LoadPrivateKeyFromPEM(keyBytes) + + ctx.SetVerify(ts.VerifyNone, passThruVerify(tb)) + + key, err := crypto.LoadPrivateKeyFromPEM([]byte(keyBytes)) if err != nil { - t.Fatal(err) + tb.Fatal(err) } + err = ctx.UsePrivateKey(key) if err != nil { - t.Fatal(err) + tb.Fatal(err) } - cert, err := crypto.LoadCertificateFromPEM(certBytes) + + cert, err := crypto.LoadCertificateFromPEM([]byte(certBytes)) if err != nil { - t.Fatal(err) + tb.Fatal(err) } + err = ctx.UseCertificate(cert) if err != nil { - t.Fatal(err) + tb.Fatal(err) } + err = ctx.SetCipherList("AES128-SHA") if err != nil { - t.Fatal(err) + tb.Fatal(err) } - server, err = Server(serverConn, ctx) + + server, err := ts.Server(serverConn, ctx) if err != nil { - t.Fatal(err) + tb.Fatal(err) } - client, err = Client(clientConn, ctx) + + client, err := ts.Client(clientConn, ctx) if err != nil { - t.Fatal(err) + tb.Fatal(err) } + return server, client } // StdlibOpenSSLConstructor function is used to create SSL/TLS connections for the standard library and OpenSSL. -func StdlibOpenSSLConstructor(t testing.TB, serverConn, clientConn net.Conn) ( - server, client HandshakingConn) { - serverStd, _ := StdlibConstructor(t, serverConn, clientConn) - _, clientSsl := OpenSSLConstructor(t, serverConn, clientConn) +func StdlibOpenSSLConstructor(tb testing.TB, serverConn, clientConn net.Conn) (HandshakingConn, HandshakingConn) { + tb.Helper() + + serverStd, _ := StdlibConstructor(tb, serverConn, clientConn) + _, clientSsl := OpenSSLConstructor(tb, serverConn, clientConn) + return serverStd, clientSsl } // OpenSSLStdlibConstructor function is used to create SSL/TLS connections for OpenSSL and the standard library. -func OpenSSLStdlibConstructor(t testing.TB, serverConn, clientConn net.Conn) ( - server, client HandshakingConn) { - _, clientStd := StdlibConstructor(t, serverConn, clientConn) - serverSsl, _ := OpenSSLConstructor(t, serverConn, clientConn) +func OpenSSLStdlibConstructor(tb testing.TB, serverConn, clientConn net.Conn) (HandshakingConn, HandshakingConn) { + tb.Helper() + + _, clientStd := StdlibConstructor(tb, serverConn, clientConn) + serverSsl, _ := OpenSSLConstructor(tb, serverConn, clientConn) + return serverSsl, clientStd } // TestStdlibSimple function is used to test simple connections of the standard library. func TestStdlibSimple(t *testing.T) { + t.Parallel() SimpleConnTest(t, StdlibConstructor) } // TestStdlibTLSv13Simple function is used to test simple connections of the standard library with TLSv1.3. func TestStdlibTLSv13Simple(t *testing.T) { + t.Parallel() SimpleConnTest(t, StdlibTLSv13Constructor) } // TestOpenSSLSimple function is used to test simple connections of OpenSSL. func TestOpenSSLSimple(t *testing.T) { + t.Parallel() SimpleConnTest(t, OpenSSLConstructor) } // TestStdlibClosing function is used to test closing connections of the standard library. func TestStdlibClosing(t *testing.T) { + t.Parallel() ClosingTest(t, StdlibConstructor) } // TestStdlibTLSv13Closing function is used to test closing connections of the standard library with TLSv1.3. func TestStdlibTLSv13Closing(t *testing.T) { + t.Parallel() ClosingTest(t, StdlibTLSv13Constructor) } -// TODO fix this -//func TestOpenSSLClosing(t *testing.T) { -// ClosingTest(t, OpenSSLConstructor) -//} +func TestOpenSSLClosing(t *testing.T) { + t.Parallel() + ClosingTest(t, OpenSSLConstructor) +} // BenchmarkStdlibThroughput function is used to benchmark the throughput of the standard library. func BenchmarkStdlibThroughput(b *testing.B) { @@ -550,21 +593,25 @@ func BenchmarkOpenSSLThroughput(b *testing.B) { // TestStdlibOpenSSLSimple function is used to test simple connections of the standard library and OpenSSL. func TestStdlibOpenSSLSimple(t *testing.T) { + t.Parallel() SimpleConnTest(t, StdlibOpenSSLConstructor) } // TestOpenSSLStdlibSimple function is used to test simple connections of OpenSSL and the standard library. func TestOpenSSLStdlibSimple(t *testing.T) { + t.Parallel() SimpleConnTest(t, OpenSSLStdlibConstructor) } // TestStdlibOpenSSLClosing function is used to test closing connections of the standard library and OpenSSL. func TestStdlibOpenSSLClosing(t *testing.T) { + t.Parallel() ClosingTest(t, StdlibOpenSSLConstructor) } // TestOpenSSLStdlibClosing function is used to test closing connections of OpenSSL and the standard library. func TestOpenSSLStdlibClosing(t *testing.T) { + t.Parallel() ClosingTest(t, OpenSSLStdlibConstructor) } @@ -579,10 +626,14 @@ func BenchmarkOpenSSLStdlibThroughput(b *testing.B) { } // FullDuplexRenegotiationTest function is used to test full-duplex renegotiation. -func FullDuplexRenegotiationTest(t testing.TB, constructor func( - t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { +func FullDuplexRenegotiationTest(tb testing.TB, constructor func( + t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn), +) { + tb.Helper() + SSLRecordSize := 16 * 1024 - serverConn, clientConn := NetPipe(t) + + serverConn, clientConn := NetPipe(tb) defer serverConn.Close() defer clientConn.Close() @@ -590,38 +641,45 @@ func FullDuplexRenegotiationTest(t testing.TB, constructor func( times := 256 dataLen := 4 * SSLRecordSize data1 := make([]byte, dataLen) - _, err := io.ReadFull(rand.Reader, data1[:]) + + _, err := io.ReadFull(rand.Reader, data1) if err != nil { - t.Fatal(err) + tb.Fatal(err) } + data2 := make([]byte, dataLen) - _, err = io.ReadFull(rand.Reader, data1[:]) + + _, err = io.ReadFull(rand.Reader, data2) if err != nil { - t.Fatal(err) + tb.Fatal(err) } // Create SSL/TLS connections - server, client := constructor(t, serverConn, clientConn) + server, client := constructor(tb, serverConn, clientConn) defer closeBoth(server, client) var wg sync.WaitGroup sendFunc := func(sender HandshakingConn, data []byte) { defer wg.Done() + for i := 0; i < times; i++ { if i == times/2 { wg.Add(1) + go func() { defer wg.Done() + err := sender.Handshake() if err != nil { - t.Fatal(err) + tb.Fatal(err) } }() } + _, err := sender.Write(data) if err != nil { - t.Fatal(err) + tb.Fatal(err) } } } @@ -631,17 +689,19 @@ func FullDuplexRenegotiationTest(t testing.TB, constructor func( buf := make([]byte, len(data)) for i := 0; i < times; i++ { - n, err := io.ReadFull(receiver, buf[:]) + n, err := io.ReadFull(receiver, buf) if err != nil { - t.Fatal(err) + tb.Fatal(err) } + if !bytes.Equal(buf[:n], data) { - t.Fatal(err) + tb.Fatal(err) } } } wg.Add(4) + go recvFunc(server, data1) go sendFunc(client, data1) go sendFunc(server, data2) @@ -651,125 +711,150 @@ func FullDuplexRenegotiationTest(t testing.TB, constructor func( // TestStdlibFullDuplexRenegotiation function is used to test full-duplex renegotiation of the standard library. func TestStdlibFullDuplexRenegotiation(t *testing.T) { + t.Parallel() FullDuplexRenegotiationTest(t, StdlibConstructor) } -// TestStdlibTLSv13FullDuplexRenegotiation function is used to test full-duplex renegotiation of the standard library with TLSv1.3. -func TestStdlibTLSv13FullDuplexRenegotiation(t *testing.T) { - FullDuplexRenegotiationTest(t, StdlibTLSv13Constructor) -} - // TestOpenSSLFullDuplexRenegotiation function is used to test full-duplex renegotiation of OpenSSL. func TestOpenSSLFullDuplexRenegotiation(t *testing.T) { + t.Parallel() FullDuplexRenegotiationTest(t, OpenSSLConstructor) } -// TestOpenSSLStdlibFullDuplexRenegotiation function is used to test full-duplex renegotiation of OpenSSL and the standard library. +// TestOpenSSLStdlibFullDuplexRenegotiation function is used to test full-duplex renegotiation of OpenSSL and the +// standard library. func TestOpenSSLStdlibFullDuplexRenegotiation(t *testing.T) { + t.Parallel() FullDuplexRenegotiationTest(t, OpenSSLStdlibConstructor) } -// TestStdlibOpenSSLFullDuplexRenegotiation function is used to test full-duplex renegotiation of the standard library and OpenSSL. +// TestStdlibOpenSSLFullDuplexRenegotiation function is used to test full-duplex renegotiation of the standard library +// and OpenSSL. func TestStdlibOpenSSLFullDuplexRenegotiation(t *testing.T) { + t.Parallel() FullDuplexRenegotiationTest(t, StdlibOpenSSLConstructor) } +func startTLSServer(t *testing.T, sslListener net.Listener, payloadSize int64, loops int, sleep time.Duration) { + t.Helper() + + for { + conn, err := sslListener.Accept() + if err != nil { + t.Error("failed to accept: ", err) + continue + } + + go func() { + defer func() { + err = conn.Close() + if err != nil { + t.Error("failed to close: ", err) + } + }() + + for i := 0; i < loops; i++ { + _, err := io.Copy(io.Discard, io.LimitReader(conn, payloadSize)) + if err != nil { + t.Error("failed to read: ", err) + return + } + + _, err = io.Copy(conn, io.LimitReader(rand.Reader, payloadSize)) + if err != nil { + t.Error("failed to write: ", err) + return + } + } + + time.Sleep(sleep) + }() + } +} + // LotsOfConns function is used to test the situation of a large number of connections. func LotsOfConns(t *testing.T, payloadSize int64, loops, clients int, sleep time.Duration, newListener func(net.Listener) net.Listener, - newClient func(net.Conn) (net.Conn, error)) { + newClient func(net.Conn) (net.Conn, error), +) { + t.Helper() + tcpListener, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatal(err) } + sslListener := newListener(tcpListener) - go func() { - for { - conn, err := sslListener.Accept() - if err != nil { - t.Error("failed accept: ", err) - continue - } - go func() { - defer func() { - err = conn.Close() - if err != nil { - t.Error("failed closing: ", err) - } - }() - for i := 0; i < loops; i++ { - _, err := io.Copy(ioutil.Discard, - io.LimitReader(conn, payloadSize)) - if err != nil { - t.Error("failed reading: ", err) - return - } - _, err = io.Copy(conn, io.LimitReader(rand.Reader, - payloadSize)) - if err != nil { - t.Error("failed writing: ", err) - return - } - } - time.Sleep(sleep) - }() - } - }() + go startTLSServer(t, sslListener, payloadSize, loops, sleep) // Create multiple client connections var wg sync.WaitGroup + for i := 0; i < clients; i++ { tcpClient, err := net.Dial(tcpListener.Addr().Network(), tcpListener.Addr().String()) if err != nil { t.Fatal(err) } + sslClient, err := newClient(tcpClient) if err != nil { t.Fatal(err) } + wg.Add(1) - go func(i int) { + + go func(_ int) { defer func() { err = sslClient.Close() if err != nil { - t.Error("failed closing: ", err) + t.Error("failed to close: ", err) } + wg.Done() }() + for i := 0; i < loops; i++ { // Write and read data _, err := io.Copy(sslClient, io.LimitReader(rand.Reader, payloadSize)) if err != nil { - t.Error("failed writing: ", err) + t.Error("failed to write: ", err) return } - _, err = io.Copy(ioutil.Discard, + + _, err = io.Copy(io.Discard, io.LimitReader(sslClient, payloadSize)) if err != nil { - t.Error("failed reading: ", err) + t.Error("failed to read: ", err) return } } + time.Sleep(sleep) }(i) } + wg.Wait() } -// TestStdlibLotsOfConns function is used to test the situation of a large number of connections of the standard library. +// TestStdlibLotsOfConns function is used to test the situation of a large number of connections of the standard +// library. func TestStdlibLotsOfConns(t *testing.T) { + t.Parallel() + // Load certificate and configure TLS - tlsCert, err := tls.X509KeyPair(certBytes, keyBytes) + tlsCert, err := tls.X509KeyPair([]byte(certBytes), []byte(keyBytes)) if err != nil { t.Fatal(err) } - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - InsecureSkipVerify: true, - CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}} + + tlsConfig := &tls.Config{} + tlsConfig.Certificates = []tls.Certificate{tlsCert} + tlsConfig.InsecureSkipVerify = true + tlsConfig.CipherSuites = []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA} + // Execute large number of connections test LotsOfConns(t, 1024*64, 10, 100, 0*time.Second, func(l net.Listener) net.Listener { @@ -779,19 +864,23 @@ func TestStdlibLotsOfConns(t *testing.T) { }) } -// TestStdlibTLSv13LotsOfConns function is used to test the situation of a large number of connections of the standard library with TLSv1.3. +// TestStdlibTLSv13LotsOfConns function is used to test the situation of a large number of connections of the standard +// library with TLSv1.3. func TestStdlibTLSv13LotsOfConns(t *testing.T) { + t.Parallel() + // Load certificate and configure TLS - tlsCert, err := tls.X509KeyPair(certBytes, keyBytes) + tlsCert, err := tls.X509KeyPair([]byte(certBytes), []byte(keyBytes)) if err != nil { t.Fatal(err) } - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - InsecureSkipVerify: true, - MinVersion: tls.VersionTLS13, - MaxVersion: tls.VersionTLS13, - } + + tlsConfig := &tls.Config{} + tlsConfig.Certificates = []tls.Certificate{tlsCert} + tlsConfig.InsecureSkipVerify = true + tlsConfig.MinVersion = tls.VersionTLS13 + tlsConfig.MaxVersion = tls.VersionTLS13 + // Execute large number of connections test LotsOfConns(t, 1024*64, 10, 100, 0*time.Second, func(l net.Listener) net.Listener { @@ -803,27 +892,34 @@ func TestStdlibTLSv13LotsOfConns(t *testing.T) { // TestOpenSSLLotsOfConns function is used to test the situation of a large number of connections of OpenSSL. func TestOpenSSLLotsOfConns(t *testing.T) { + t.Parallel() + // Create SSL context and configure - ctx, err := NewCtx() + ctx, err := ts.NewCtx() if err != nil { t.Fatal(err) } - key, err := crypto.LoadPrivateKeyFromPEM(keyBytes) + + key, err := crypto.LoadPrivateKeyFromPEM([]byte(keyBytes)) if err != nil { t.Fatal(err) } + err = ctx.UsePrivateKey(key) if err != nil { t.Fatal(err) } - cert, err := crypto.LoadCertificateFromPEM(certBytes) + + cert, err := crypto.LoadCertificateFromPEM([]byte(certBytes)) if err != nil { t.Fatal(err) } + err = ctx.UseCertificate(cert) if err != nil { t.Fatal(err) } + err = ctx.SetCipherList("AES128-SHA") if err != nil { t.Fatal(err) @@ -831,8 +927,8 @@ func TestOpenSSLLotsOfConns(t *testing.T) { // Execute large number of connections test LotsOfConns(t, 1024*64, 10, 100, 0*time.Second, func(l net.Listener) net.Listener { - return NewListener(l, ctx) + return ts.NewListener(l, ctx) }, func(c net.Conn) (net.Conn, error) { - return Client(c, ctx) + return ts.Client(c, ctx) }) } diff --git a/tickets.go b/tickets.go index 8ab07d7..14e0cbe 100644 --- a/tickets.go +++ b/tickets.go @@ -102,58 +102,56 @@ func (t *TicketStore) digestEngine() *C.ENGINE { const ( // instruct to do a handshake - ticket_resp_requireHandshake = 0 + ticketRespRequireHandshake = 0 // crypto context is set up correctly - ticket_resp_sessionOk = 1 + ticketRespSessionOk = 1 // crypto context is ok, but the ticket should be reissued - ticket_resp_renewSession = 2 + ticketRespRenewSession = 2 // we had a problem that shouldn't fall back to doing a handshake - ticket_resp_error = -1 - + ticketRespError = -1 // asked to create session crypto context - ticket_req_newSession = 1 + ticketReqNewSession = 1 // asked to load crypto context for a previous session - ticket_req_lookupSession = 0 + ticketReqLookupSession = 0 ) //export go_ticket_key_cb_thunk -func go_ticket_key_cb_thunk(p unsafe.Pointer, s *C.SSL, key_name *C.uchar, - iv *C.uchar, cctx *C.EVP_CIPHER_CTX, hctx *C.HMAC_CTX, enc C.int) C.int { - +func go_ticket_key_cb_thunk(pctx unsafe.Pointer, keyName *C.uchar, cctx *C.EVP_CIPHER_CTX, hctx *C.HMAC_CTX, enc C.int, +) C.int { // no panic's allowed. it's super hard to guarantee any state at this point // so just abort everything. defer func() { if err := recover(); err != nil { - //logger.Critf("openssl: ticket key callback panic'd: %v", err) + // logger.Critf("openssl: ticket key callback panic'd: %v", err) os.Exit(1) } }() - ctx := (*Ctx)(p) - store := ctx.ticket_store + ctx := (*Ctx)(pctx) + store := ctx.ticketStore if store == nil { - // TODO(jeff): should this be an error condition? it doesn't make sense + // should this be an error condition? it doesn't make sense // to be called if we don't have a store I believe, but that's probably // not worth aborting the handshake which is what I believe returning // an error would do. - return ticket_resp_requireHandshake + return ticketRespRequireHandshake } - ctx.ticket_store_mu.Lock() - defer ctx.ticket_store_mu.Unlock() + ctx.ticketStoreMu.Lock() + defer ctx.ticketStoreMu.Unlock() switch enc { - case ticket_req_newSession: + case ticketReqNewSession: key := store.Keys.Current() if key == nil { key = store.Keys.New() if key == nil { - return ticket_resp_requireHandshake + return ticketRespRequireHandshake } } C.memcpy( - unsafe.Pointer(key_name), + unsafe.Pointer(keyName), unsafe.Pointer(&key.Name[0]), KeyNameSize) C.EVP_EncryptInit_ex( @@ -169,21 +167,21 @@ func go_ticket_key_cb_thunk(p unsafe.Pointer, s *C.SSL, key_name *C.uchar, (*C.EVP_MD)(store.DigestCtx.Digest.Ptr()), store.digestEngine()) - return ticket_resp_sessionOk + return ticketRespSessionOk - case ticket_req_lookupSession: + case ticketReqLookupSession: var name TicketName C.memcpy( unsafe.Pointer(&name[0]), - unsafe.Pointer(key_name), + unsafe.Pointer(keyName), KeyNameSize) key := store.Keys.Lookup(name) if key == nil { - return ticket_resp_requireHandshake + return ticketRespRequireHandshake } if store.Keys.Expired(name) { - return ticket_resp_requireHandshake + return ticketRespRequireHandshake } C.EVP_DecryptInit_ex( @@ -200,20 +198,20 @@ func go_ticket_key_cb_thunk(p unsafe.Pointer, s *C.SSL, key_name *C.uchar, store.digestEngine()) if store.Keys.ShouldRenew(name) { - return ticket_resp_renewSession + return ticketRespRenewSession } - return ticket_resp_sessionOk + return ticketRespSessionOk default: - return ticket_resp_error + return ticketRespError } } // SetTicketStore sets the ticket store for the context so that clients can do // ticket based session resumption. If the store is nil, the func (c *Ctx) SetTicketStore(store *TicketStore) { - c.ticket_store = store + c.ticketStore = store if store == nil { C.X_SSL_CTX_set_tlsext_ticket_key_cb(c.ctx, nil) diff --git a/utils/errors.go b/utils/errors.go index bab314c..7cb8e18 100644 --- a/utils/errors.go +++ b/utils/errors.go @@ -15,16 +15,16 @@ package utils import ( - "errors" + "fmt" "strings" ) -// ErrorGroup collates errors +// ErrorGroup collates errors. type ErrorGroup struct { Errors []error } -// Add adds an error to an existing error group +// Add adds an error to an existing error group. func (e *ErrorGroup) Add(err error) { if err != nil { e.Errors = append(e.Errors, err) @@ -39,12 +39,15 @@ func (e *ErrorGroup) Finalize() error { if len(e.Errors) == 0 { return nil } + if len(e.Errors) == 1 { return e.Errors[0] } + msgs := make([]string, 0, len(e.Errors)) for _, err := range e.Errors { msgs = append(msgs, err.Error()) } - return errors.New(strings.Join(msgs, "\n")) + + return fmt.Errorf("errors: %s", strings.Join(msgs, "\n")) } diff --git a/utils/future.go b/utils/future.go index fa1bbbf..ce9eb67 100644 --- a/utils/future.go +++ b/utils/future.go @@ -35,6 +35,7 @@ type Future struct { // NewFuture returns an initialized and ready Future. func NewFuture() *Future { mutex := &sync.Mutex{} + return &Future{ mutex: mutex, cond: sync.NewCond(mutex), @@ -45,35 +46,40 @@ func NewFuture() *Future { } // Get blocks until the Future has a value set. -func (self *Future) Get() (interface{}, error) { - self.mutex.Lock() - defer self.mutex.Unlock() +func (f *Future) Get() (interface{}, error) { + f.mutex.Lock() + defer f.mutex.Unlock() + for { - if self.received { - return self.val, self.err + if f.received { + return f.val, f.err } - self.cond.Wait() + + f.cond.Wait() } } // Fired returns whether or not a value has been set. If Fired is true, Get // won't block. -func (self *Future) Fired() bool { - self.mutex.Lock() - defer self.mutex.Unlock() - return self.received +func (f *Future) Fired() bool { + f.mutex.Lock() + defer f.mutex.Unlock() + + return f.received } // Set provides the value to present and future Get calls. If Set has already // been called, this is a no-op. -func (self *Future) Set(val interface{}, err error) { - self.mutex.Lock() - defer self.mutex.Unlock() - if self.received { +func (f *Future) Set(val interface{}, err error) { + f.mutex.Lock() + defer f.mutex.Unlock() + + if f.received { return } - self.received = true - self.val = val - self.err = err - self.cond.Broadcast() + + f.received = true + f.val = val + f.err = err + f.cond.Broadcast() }