diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index a54e710..2787110 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -35,7 +35,7 @@ jobs: - name: Build Tongsuo run: | cd Tongsuo - ./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls + ./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls enable-export-sm4 make -j4 make install @@ -69,7 +69,7 @@ jobs: - name: Build Tongsuo run: | cd Tongsuo - ./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls + ./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls enable-export-sm4 make -j4 make install @@ -108,7 +108,7 @@ jobs: - name: Build Tongsuo Static run: | cd tongsuo - ./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls no-shared + ./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls enable-export-sm4 no-shared make -j4 make install @@ -147,7 +147,7 @@ jobs: run: | mkdir _build cd _build - perl ..\Configure VC-WIN64A no-makedepend --prefix=%RUNNER_TEMP%\tongsuo enable-ntls + perl ..\Configure VC-WIN64A no-makedepend --prefix=%RUNNER_TEMP%\tongsuo enable-ntls enable-export-sm4 nmake /S nmake install working-directory: Tongsuo diff --git a/README.md b/README.md index eae4ca7..02d51e9 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ cd Tongsuo git checkout 8.3-stable -./config --prefix=/opt/tongsuo --libdir=/opt/tongsuo/lib enable-ntls +./config --prefix=/opt/tongsuo --libdir=/opt/tongsuo/lib enable-ntls enable-export-sm4 make -j make install ``` diff --git a/crypto/init.go b/crypto/init.go index 5be290c..6c0312c 100644 --- a/crypto/init.go +++ b/crypto/init.go @@ -48,6 +48,7 @@ var ( ErrInternalError = errors.New("internal error") ErrEmptyKey = errors.New("empty key") ErrNoData = errors.New("no data") + ErrInvalidKeySize = errors.New("invalid key size") ) func init() { diff --git a/crypto/shim.h b/crypto/shim.h index fb3b73f..70f3851 100644 --- a/crypto/shim.h +++ b/crypto/shim.h @@ -25,6 +25,7 @@ #include #include #include +#include /* shim methods */ extern int X_tscrypto_init(); diff --git a/crypto/sm4/sm4.go b/crypto/sm4/sm4.go index 5583fe0..74904aa 100644 --- a/crypto/sm4/sm4.go +++ b/crypto/sm4/sm4.go @@ -12,11 +12,18 @@ import "C" import ( "bytes" + "crypto/cipher" "fmt" + "unsafe" "github.com/tongsuo-project/tongsuo-go-sdk/crypto" ) +const ( + BlockSize = 16 + KeySize = 16 +) + type Encrypter interface { // crypto.EncryptionCipherCtx SetPadding(pad bool) @@ -50,6 +57,50 @@ type sm4Decrypter struct { tag []byte } +type sm4Cipher struct { + rk [32]uint32 +} + +func (c *sm4Cipher) BlockSize() int { + return BlockSize +} + +func NewCipher(key []byte) (cipher.Block, error) { + if len(key) != KeySize { + return nil, fmt.Errorf("invalid key size: %w", crypto.ErrInvalidKeySize) + } + + cipher := &sm4Cipher{} + ret := C.SM4_set_key((*C.uchar)(&key[0]), (*C.SM4_KEY)(unsafe.Pointer(&cipher.rk))) + if ret != 1 { + return nil, fmt.Errorf("failed to set key: %w", crypto.ErrInternalError) + } + + return cipher, nil +} + +func (c *sm4Cipher) Encrypt(dst, src []byte) { + if len(src) < BlockSize { + panic("sm4: input not full block") + } + if len(dst) < BlockSize { + panic("sm4: output not full block") + } + + C.SM4_encrypt((*C.uchar)(&src[0]), (*C.uchar)(&dst[0]), (*C.SM4_KEY)(unsafe.Pointer(&c.rk))) +} + +func (c *sm4Cipher) Decrypt(dst, src []byte) { + if len(src) < BlockSize { + panic("sm4: input not full block") + } + if len(dst) < BlockSize { + panic("sm4: output not full block") + } + + C.SM4_decrypt((*C.uchar)(&src[0]), (*C.uchar)(&dst[0]), (*C.SM4_KEY)(unsafe.Pointer(&c.rk))) +} + func getSM4Cipher(mode int) (*crypto.Cipher, error) { var cipher *crypto.Cipher var err error diff --git a/crypto/sm4/sm4_test.go b/crypto/sm4/sm4_test.go index a81b71c..f25de20 100644 --- a/crypto/sm4/sm4_test.go +++ b/crypto/sm4/sm4_test.go @@ -9,6 +9,7 @@ package sm4_test import ( "bytes" + "crypto/cipher" "encoding/hex" "strings" "testing" @@ -20,6 +21,179 @@ import ( const hexPlainText1 = `AAAAAAAAAAAAAAAABBBBBBBBBBBBBBBBCCCCCCCCCCCCCCCCDDDDDDDDDDDDDDDDEEEEEEEEEEEEEEEEFFFFFFFFFFFFFFFFE EEEEEEEEEEEEEEEAAAAAAAAAAAAAAAA` +func TestSM4ECBWithCipherBlock(t *testing.T) { + t.Parallel() + + key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + cipherText, _ := hex.DecodeString("681EDF34D206965E86B3E94F536E4246") + + block, err := sm4.NewCipher(key) + if err != nil { + t.Fatal("failed to create SM4 cipher: ", err) + } + + cipherText1 := make([]byte, len(plainText)) + + for i := 0; i < len(plainText); i += block.BlockSize() { + block.Encrypt(cipherText1[i:i+block.BlockSize()], plainText[i:i+block.BlockSize()]) + } + + if !bytes.Equal(cipherText1, cipherText) { + t.Fatalf("exp:%x got:%x", cipherText, cipherText1) + } + + plainText1 := make([]byte, len(cipherText1)) + + for i := 0; i < len(cipherText1); i += block.BlockSize() { + block.Decrypt(plainText1[i:i+block.BlockSize()], cipherText1[i:i+block.BlockSize()]) + } + + if !bytes.Equal(plainText, plainText1) { + t.Fatalf("exp:%x got:%x", plainText, plainText1) + } +} + +func testCryptWithCipherBlock(t *testing.T, mode string, key, iv, plainText, cipherText []byte) { + t.Helper() + + block, err := sm4.NewCipher(key) + if err != nil { + t.Fatal("failed to create SM4 cipher: ", err) + } + + cipherText1 := make([]byte, len(plainText)) + + switch mode { + case "CBC": + stream := cipher.NewCBCEncrypter(block, iv) + stream.CryptBlocks(cipherText1, plainText) + case "CFB": + stream := cipher.NewCFBEncrypter(block, iv) + stream.XORKeyStream(cipherText1, plainText) + case "OFB": + stream := cipher.NewOFB(block, iv) + stream.XORKeyStream(cipherText1, plainText) + case "CTR": + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(cipherText1, plainText) + } + + if !bytes.Equal(cipherText1, cipherText) { + t.Fatalf("exp:%x got:%x", cipherText, cipherText1) + } + + plainText1 := make([]byte, len(plainText)) + + switch mode { + case "CBC": + stream := cipher.NewCBCDecrypter(block, iv) + stream.CryptBlocks(plainText1, cipherText1) + case "CFB": + stream := cipher.NewCFBDecrypter(block, iv) + stream.XORKeyStream(plainText1, cipherText1) + case "OFB": + stream := cipher.NewOFB(block, iv) + stream.XORKeyStream(plainText1, cipherText1) + case "CTR": + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(plainText1, cipherText1) + } + + if !bytes.Equal(plainText, plainText1) { + t.Fatalf("exp:%x got:%x", plainText, plainText1) + } +} + +func TestSM4CBCWithCipherBlock(t *testing.T) { + t.Parallel() + + key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210") + cipherText, _ := hex.DecodeString("2677F46B09C122CC975533105BD4A22AF6125F7275CE552C3A2BBCF533DE8A3B") + + testCryptWithCipherBlock(t, "CBC", key, iv, plainText, cipherText) +} + +func TestSM4CFBWithCipherBlock(t *testing.T) { + t.Parallel() + + key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210") + cipherText, _ := hex.DecodeString("693D9A535BAD5BB1786F53D7253A70569ED258A85A0467CC92AAB393DD978995") + + testCryptWithCipherBlock(t, "CFB", key, iv, plainText, cipherText) +} + +func TestSM4OFBWithCipherBlock(t *testing.T) { + t.Parallel() + + key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210") + cipherText, _ := hex.DecodeString("693D9A535BAD5BB1786F53D7253A7056F2075D28B5235F58D50027E4177D2BCE") + + testCryptWithCipherBlock(t, "OFB", key, iv, plainText, cipherText) +} + +func TestSM4CTRWithCipherBlock(t *testing.T) { + t.Parallel() + + key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + hexCipherText := `C2B4759E78AC3CF43D0852F4E8D5F9FD7256E8A5FCB65A350EE00630912E44492A0B17E1B85B060D0FBA612D8A95831638 +B361FD5FFACD942F081485A83CA35D` + plainText, _ := hex.DecodeString(strings.ReplaceAll(hexPlainText1, "\n", "")) + cipherText, _ := hex.DecodeString(strings.ReplaceAll(hexCipherText, "\n", "")) + + testCryptWithCipherBlock(t, "CTR", key, iv, plainText, cipherText) +} + +func TestSM4GCMWithCipherBlock(t *testing.T) { + t.Parallel() + + key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + iv, _ := hex.DecodeString("00001234567800000000ABCD") + aad, _ := hex.DecodeString("FEEDFACEDEADBEEFFEEDFACEDEADBEEFABADDAD2") + tag, _ := hex.DecodeString("83DE3541E4C2B58177E065A9BF7B62EC") + hexCipherText := `17F399F08C67D5EE19D0DC9969C4BB7D5FD46FD3756489069157B282BB200735D82710CA5C22F0CCFA7CBF93D496AC15A5 +6834CBCF98C397B4024A2691233B8D` + plainText, _ := hex.DecodeString(strings.ReplaceAll(hexPlainText1, "\n", "")) + cipherText, _ := hex.DecodeString(strings.ReplaceAll(hexCipherText, "\n", "")) + + block, err := sm4.NewCipher(key) + if err != nil { + t.Fatal("failed to create SM4 cipher: ", err) + } + + stream, err := cipher.NewGCM(block) + if err != nil { + t.Fatal("failed to create GCM: ", err) + } + + cipherText1 := stream.Seal(nil, iv, plainText, aad) + + if !bytes.Equal(cipherText1, append(cipherText, tag...)) { + t.Fatalf("exp:%x got:%x", cipherText1, append(cipherText, tag...)) + } + + stream2, err := cipher.NewGCM(block) + if err != nil { + t.Fatal("failed to create GCM: ", err) + } + + plainText1, err := stream2.Open(nil, iv, cipherText1, aad) + if err != nil { + t.Fatal("failed to decrypt: ", err) + } + + if !bytes.Equal(plainText1, plainText) { + t.Fatalf("exp:%x got:%x", plainText1, plainText) + } +} + func doEncrypt(t *testing.T, mode int, key, iv, plainText, cipherText []byte) { t.Helper() diff --git a/examples/sm4/main.go b/examples/sm4/main.go index d7b5c54..bd30437 100644 --- a/examples/sm4/main.go +++ b/examples/sm4/main.go @@ -9,11 +9,11 @@ package main import ( "bytes" + "crypto/cipher" "encoding/hex" "fmt" "log" - "github.com/tongsuo-project/tongsuo-go-sdk/crypto" "github.com/tongsuo-project/tongsuo-go-sdk/crypto/sm4" ) @@ -23,20 +23,18 @@ func sm4CBCEncrypt() { plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210") cipherText, _ := hex.DecodeString("2677F46B09C122CC975533105BD4A22AF6125F7275CE552C3A2BBCF533DE8A3B") - enc, err := sm4.NewEncrypter(crypto.CipherModeCBC, key, iv) + block, err := sm4.NewCipher(key) if err != nil { - log.Fatal("failed to create encrypter: ", err) + log.Fatal("failed to create SM4 cipher: ", err) } - enc.SetPadding(false) + cipherText1 := make([]byte, len(plainText)) - actualCipherText, err := enc.EncryptAll(plainText) - if err != nil { - log.Fatal("failed to encrypt: ", err) - } + stream := cipher.NewCBCEncrypter(block, iv) + stream.CryptBlocks(cipherText1, plainText) - if !bytes.Equal(cipherText, actualCipherText) { - log.Fatalf("exp:%x got:%x", cipherText, actualCipherText) + if !bytes.Equal(cipherText1, cipherText) { + log.Fatalf("exp:%x got:%x", cipherText, cipherText1) } fmt.Println("[sm4CBCEncrypt]") @@ -52,20 +50,18 @@ func sm4CBCDecrypt() { plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210") cipherText, _ := hex.DecodeString("2677F46B09C122CC975533105BD4A22AF6125F7275CE552C3A2BBCF533DE8A3B") - enc, err := sm4.NewDecrypter(crypto.CipherModeCBC, key, iv) + block, err := sm4.NewCipher(key) if err != nil { - log.Fatal("failed to create decrypter: ", err) + log.Fatal("failed to create SM4 cipher: ", err) } - enc.SetPadding(false) + plainText1 := make([]byte, len(cipherText)) - actualPlainText, err := enc.DecryptAll(cipherText) - if err != nil { - log.Fatal("failed to decrypt: ", err) - } + stream := cipher.NewCBCDecrypter(block, iv) + stream.CryptBlocks(plainText1, cipherText) - if !bytes.Equal(plainText, actualPlainText) { - log.Fatalf("exp:%x got:%x", plainText, actualPlainText) + if !bytes.Equal(plainText, plainText1) { + log.Fatalf("exp:%x got:%x", plainText, plainText1) } fmt.Println("[sm4CBCDecrypt]") @@ -83,29 +79,20 @@ func sm4GCMEncrypt() { plainText, _ := hex.DecodeString("AAAAAAAAAAAAAAAABBBBBBBBBBBBBBBBCCCCCCCCCCCCCCCCDDDDDDDDDDDDDDDDEEEEEEEEEEEEEEEEFFFFFFFFFFFFFFFFEEEEEEEEEEEEEEEEAAAAAAAAAAAAAAAA") cipherText, _ := hex.DecodeString("17F399F08C67D5EE19D0DC9969C4BB7D5FD46FD3756489069157B282BB200735D82710CA5C22F0CCFA7CBF93D496AC15A56834CBCF98C397B4024A2691233B8D") - enc, err := sm4.NewEncrypter(crypto.CipherModeGCM, key, iv) + block, err := sm4.NewCipher(key) if err != nil { - log.Fatal("failed to create encrypter: ", err) + log.Fatal("failed to create SM4 cipher: ", err) } - enc.SetAAD(aad) - - actualCipherText, err := enc.EncryptAll(plainText) + stream, err := cipher.NewGCM(block) if err != nil { - log.Fatal("failed to encrypt: ", err) + log.Fatal("failed to create GCM: ", err) } - if !bytes.Equal(cipherText, actualCipherText) { - log.Fatalf("exp:%x got:%x", cipherText, actualCipherText) - } - - actualTag, err := enc.GetTag() - if err != nil { - log.Fatal("failed to get tag: ", err) - } + cipherText1 := stream.Seal(nil, iv, plainText, aad) - if !bytes.Equal(tag, actualTag) { - log.Fatalf("exp:%x got:%x", tag, actualTag) + if !bytes.Equal(cipherText1, append(cipherText, tag...)) { + log.Fatalf("exp:%x got:%x", cipherText1, append(cipherText, tag...)) } fmt.Println("[sm4GCMEncrypt]") @@ -125,21 +112,23 @@ func sm4GCMDecrypt() { plainText, _ := hex.DecodeString("AAAAAAAAAAAAAAAABBBBBBBBBBBBBBBBCCCCCCCCCCCCCCCCDDDDDDDDDDDDDDDDEEEEEEEEEEEEEEEEFFFFFFFFFFFFFFFFEEEEEEEEEEEEEEEEAAAAAAAAAAAAAAAA") cipherText, _ := hex.DecodeString("17F399F08C67D5EE19D0DC9969C4BB7D5FD46FD3756489069157B282BB200735D82710CA5C22F0CCFA7CBF93D496AC15A56834CBCF98C397B4024A2691233B8D") - dec, err := sm4.NewDecrypter(crypto.CipherModeGCM, key, iv) + block, err := sm4.NewCipher(key) if err != nil { - log.Fatal("failed to create decrypter: ", err) + log.Fatal("failed to create SM4 cipher: ", err) } - dec.SetTag(tag) - dec.SetAAD(aad) + stream, err := cipher.NewGCM(block) + if err != nil { + log.Fatal("failed to create GCM: ", err) + } - actualPlainText, err := dec.DecryptAll(cipherText) + plainText1, err := stream.Open(nil, iv, append(cipherText, tag...), aad) if err != nil { log.Fatal("failed to decrypt: ", err) } - if !bytes.Equal(plainText, actualPlainText) { - log.Fatalf("exp:%x got:%x", plainText, actualPlainText) + if !bytes.Equal(plainText1, plainText) { + log.Fatalf("exp:%x got:%x", plainText1, plainText) } fmt.Println("[sm4GCMDecrypt]")