Skip to content

Commit

Permalink
Merge pull request #37 from dongbeiouba/refactor/sm4
Browse files Browse the repository at this point in the history
SM4 supports cipher.Block interface
  • Loading branch information
dongbeiouba authored Dec 11, 2024
2 parents 7c0ad4e + 9247fed commit c5bd3b9
Show file tree
Hide file tree
Showing 7 changed files with 263 additions and 47 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
- name: Build Tongsuo
run: |
cd Tongsuo
./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls
./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls enable-export-sm4
make -j4
make install
Expand Down Expand Up @@ -69,7 +69,7 @@ jobs:
- name: Build Tongsuo
run: |
cd Tongsuo
./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls
./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls enable-export-sm4
make -j4
make install
Expand Down Expand Up @@ -108,7 +108,7 @@ jobs:
- name: Build Tongsuo Static
run: |
cd tongsuo
./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls no-shared
./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls enable-export-sm4 no-shared
make -j4
make install
Expand Down Expand Up @@ -147,7 +147,7 @@ jobs:
run: |
mkdir _build
cd _build
perl ..\Configure VC-WIN64A no-makedepend --prefix=%RUNNER_TEMP%\tongsuo enable-ntls
perl ..\Configure VC-WIN64A no-makedepend --prefix=%RUNNER_TEMP%\tongsuo enable-ntls enable-export-sm4
nmake /S
nmake install
working-directory: Tongsuo
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ cd Tongsuo

git checkout 8.3-stable

./config --prefix=/opt/tongsuo --libdir=/opt/tongsuo/lib enable-ntls
./config --prefix=/opt/tongsuo --libdir=/opt/tongsuo/lib enable-ntls enable-export-sm4
make -j
make install
```
Expand Down
1 change: 1 addition & 0 deletions crypto/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ var (
ErrInternalError = errors.New("internal error")
ErrEmptyKey = errors.New("empty key")
ErrNoData = errors.New("no data")
ErrInvalidKeySize = errors.New("invalid key size")
)

func init() {
Expand Down
1 change: 1 addition & 0 deletions crypto/shim.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <openssl/x509v3.h>
#include <openssl/ec.h>
#include <openssl/opensslv.h>
#include <openssl/sm4.h>

/* shim methods */
extern int X_tscrypto_init();
Expand Down
51 changes: 51 additions & 0 deletions crypto/sm4/sm4.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,18 @@ import "C"

import (
"bytes"
"crypto/cipher"
"fmt"
"unsafe"

"github.com/tongsuo-project/tongsuo-go-sdk/crypto"
)

const (
BlockSize = 16
KeySize = 16
)

type Encrypter interface {
// crypto.EncryptionCipherCtx
SetPadding(pad bool)
Expand Down Expand Up @@ -50,6 +57,50 @@ type sm4Decrypter struct {
tag []byte
}

type sm4Cipher struct {
rk [32]uint32
}

func (c *sm4Cipher) BlockSize() int {
return BlockSize
}

func NewCipher(key []byte) (cipher.Block, error) {
if len(key) != KeySize {
return nil, fmt.Errorf("invalid key size: %w", crypto.ErrInvalidKeySize)
}

cipher := &sm4Cipher{}
ret := C.SM4_set_key((*C.uchar)(&key[0]), (*C.SM4_KEY)(unsafe.Pointer(&cipher.rk)))
if ret != 1 {
return nil, fmt.Errorf("failed to set key: %w", crypto.ErrInternalError)
}

return cipher, nil
}

func (c *sm4Cipher) Encrypt(dst, src []byte) {
if len(src) < BlockSize {
panic("sm4: input not full block")
}
if len(dst) < BlockSize {
panic("sm4: output not full block")
}

C.SM4_encrypt((*C.uchar)(&src[0]), (*C.uchar)(&dst[0]), (*C.SM4_KEY)(unsafe.Pointer(&c.rk)))
}

func (c *sm4Cipher) Decrypt(dst, src []byte) {
if len(src) < BlockSize {
panic("sm4: input not full block")
}
if len(dst) < BlockSize {
panic("sm4: output not full block")
}

C.SM4_decrypt((*C.uchar)(&src[0]), (*C.uchar)(&dst[0]), (*C.SM4_KEY)(unsafe.Pointer(&c.rk)))
}

func getSM4Cipher(mode int) (*crypto.Cipher, error) {
var cipher *crypto.Cipher
var err error
Expand Down
174 changes: 174 additions & 0 deletions crypto/sm4/sm4_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package sm4_test

import (
"bytes"
"crypto/cipher"
"encoding/hex"
"strings"
"testing"
Expand All @@ -20,6 +21,179 @@ import (
const hexPlainText1 = `AAAAAAAAAAAAAAAABBBBBBBBBBBBBBBBCCCCCCCCCCCCCCCCDDDDDDDDDDDDDDDDEEEEEEEEEEEEEEEEFFFFFFFFFFFFFFFFE
EEEEEEEEEEEEEEEAAAAAAAAAAAAAAAA`

func TestSM4ECBWithCipherBlock(t *testing.T) {
t.Parallel()

key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
cipherText, _ := hex.DecodeString("681EDF34D206965E86B3E94F536E4246")

block, err := sm4.NewCipher(key)
if err != nil {
t.Fatal("failed to create SM4 cipher: ", err)
}

cipherText1 := make([]byte, len(plainText))

for i := 0; i < len(plainText); i += block.BlockSize() {
block.Encrypt(cipherText1[i:i+block.BlockSize()], plainText[i:i+block.BlockSize()])
}

if !bytes.Equal(cipherText1, cipherText) {
t.Fatalf("exp:%x got:%x", cipherText, cipherText1)
}

plainText1 := make([]byte, len(cipherText1))

for i := 0; i < len(cipherText1); i += block.BlockSize() {
block.Decrypt(plainText1[i:i+block.BlockSize()], cipherText1[i:i+block.BlockSize()])
}

if !bytes.Equal(plainText, plainText1) {
t.Fatalf("exp:%x got:%x", plainText, plainText1)
}
}

func testCryptWithCipherBlock(t *testing.T, mode string, key, iv, plainText, cipherText []byte) {
t.Helper()

block, err := sm4.NewCipher(key)
if err != nil {
t.Fatal("failed to create SM4 cipher: ", err)
}

cipherText1 := make([]byte, len(plainText))

switch mode {
case "CBC":
stream := cipher.NewCBCEncrypter(block, iv)
stream.CryptBlocks(cipherText1, plainText)
case "CFB":
stream := cipher.NewCFBEncrypter(block, iv)
stream.XORKeyStream(cipherText1, plainText)
case "OFB":
stream := cipher.NewOFB(block, iv)
stream.XORKeyStream(cipherText1, plainText)
case "CTR":
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(cipherText1, plainText)
}

if !bytes.Equal(cipherText1, cipherText) {
t.Fatalf("exp:%x got:%x", cipherText, cipherText1)
}

plainText1 := make([]byte, len(plainText))

switch mode {
case "CBC":
stream := cipher.NewCBCDecrypter(block, iv)
stream.CryptBlocks(plainText1, cipherText1)
case "CFB":
stream := cipher.NewCFBDecrypter(block, iv)
stream.XORKeyStream(plainText1, cipherText1)
case "OFB":
stream := cipher.NewOFB(block, iv)
stream.XORKeyStream(plainText1, cipherText1)
case "CTR":
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(plainText1, cipherText1)
}

if !bytes.Equal(plainText, plainText1) {
t.Fatalf("exp:%x got:%x", plainText, plainText1)
}
}

func TestSM4CBCWithCipherBlock(t *testing.T) {
t.Parallel()

key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210")
cipherText, _ := hex.DecodeString("2677F46B09C122CC975533105BD4A22AF6125F7275CE552C3A2BBCF533DE8A3B")

testCryptWithCipherBlock(t, "CBC", key, iv, plainText, cipherText)
}

func TestSM4CFBWithCipherBlock(t *testing.T) {
t.Parallel()

key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210")
cipherText, _ := hex.DecodeString("693D9A535BAD5BB1786F53D7253A70569ED258A85A0467CC92AAB393DD978995")

testCryptWithCipherBlock(t, "CFB", key, iv, plainText, cipherText)
}

func TestSM4OFBWithCipherBlock(t *testing.T) {
t.Parallel()

key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210")
cipherText, _ := hex.DecodeString("693D9A535BAD5BB1786F53D7253A7056F2075D28B5235F58D50027E4177D2BCE")

testCryptWithCipherBlock(t, "OFB", key, iv, plainText, cipherText)
}

func TestSM4CTRWithCipherBlock(t *testing.T) {
t.Parallel()

key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
hexCipherText := `C2B4759E78AC3CF43D0852F4E8D5F9FD7256E8A5FCB65A350EE00630912E44492A0B17E1B85B060D0FBA612D8A95831638
B361FD5FFACD942F081485A83CA35D`
plainText, _ := hex.DecodeString(strings.ReplaceAll(hexPlainText1, "\n", ""))
cipherText, _ := hex.DecodeString(strings.ReplaceAll(hexCipherText, "\n", ""))

testCryptWithCipherBlock(t, "CTR", key, iv, plainText, cipherText)
}

func TestSM4GCMWithCipherBlock(t *testing.T) {
t.Parallel()

key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
iv, _ := hex.DecodeString("00001234567800000000ABCD")
aad, _ := hex.DecodeString("FEEDFACEDEADBEEFFEEDFACEDEADBEEFABADDAD2")
tag, _ := hex.DecodeString("83DE3541E4C2B58177E065A9BF7B62EC")
hexCipherText := `17F399F08C67D5EE19D0DC9969C4BB7D5FD46FD3756489069157B282BB200735D82710CA5C22F0CCFA7CBF93D496AC15A5
6834CBCF98C397B4024A2691233B8D`
plainText, _ := hex.DecodeString(strings.ReplaceAll(hexPlainText1, "\n", ""))
cipherText, _ := hex.DecodeString(strings.ReplaceAll(hexCipherText, "\n", ""))

block, err := sm4.NewCipher(key)
if err != nil {
t.Fatal("failed to create SM4 cipher: ", err)
}

stream, err := cipher.NewGCM(block)
if err != nil {
t.Fatal("failed to create GCM: ", err)
}

cipherText1 := stream.Seal(nil, iv, plainText, aad)

if !bytes.Equal(cipherText1, append(cipherText, tag...)) {
t.Fatalf("exp:%x got:%x", cipherText1, append(cipherText, tag...))
}

stream2, err := cipher.NewGCM(block)
if err != nil {
t.Fatal("failed to create GCM: ", err)
}

plainText1, err := stream2.Open(nil, iv, cipherText1, aad)
if err != nil {
t.Fatal("failed to decrypt: ", err)
}

if !bytes.Equal(plainText1, plainText) {
t.Fatalf("exp:%x got:%x", plainText1, plainText)
}
}

func doEncrypt(t *testing.T, mode int, key, iv, plainText, cipherText []byte) {
t.Helper()

Expand Down
Loading

0 comments on commit c5bd3b9

Please sign in to comment.