Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SM4 supports cipher.Block interface #37

Merged
merged 1 commit into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
- name: Build Tongsuo
run: |
cd Tongsuo
./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls
./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls enable-export-sm4
make -j4
make install

Expand Down Expand Up @@ -69,7 +69,7 @@ jobs:
- name: Build Tongsuo
run: |
cd Tongsuo
./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls
./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls enable-export-sm4
make -j4
make install

Expand Down Expand Up @@ -108,7 +108,7 @@ jobs:
- name: Build Tongsuo Static
run: |
cd tongsuo
./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls no-shared
./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls enable-export-sm4 no-shared
make -j4
make install

Expand Down Expand Up @@ -147,7 +147,7 @@ jobs:
run: |
mkdir _build
cd _build
perl ..\Configure VC-WIN64A no-makedepend --prefix=%RUNNER_TEMP%\tongsuo enable-ntls
perl ..\Configure VC-WIN64A no-makedepend --prefix=%RUNNER_TEMP%\tongsuo enable-ntls enable-export-sm4
nmake /S
nmake install
working-directory: Tongsuo
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ cd Tongsuo

git checkout 8.3-stable

./config --prefix=/opt/tongsuo --libdir=/opt/tongsuo/lib enable-ntls
./config --prefix=/opt/tongsuo --libdir=/opt/tongsuo/lib enable-ntls enable-export-sm4
make -j
make install
```
Expand Down
1 change: 1 addition & 0 deletions crypto/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ var (
ErrInternalError = errors.New("internal error")
ErrEmptyKey = errors.New("empty key")
ErrNoData = errors.New("no data")
ErrInvalidKeySize = errors.New("invalid key size")
)

func init() {
Expand Down
1 change: 1 addition & 0 deletions crypto/shim.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <openssl/x509v3.h>
#include <openssl/ec.h>
#include <openssl/opensslv.h>
#include <openssl/sm4.h>

/* shim methods */
extern int X_tscrypto_init();
Expand Down
51 changes: 51 additions & 0 deletions crypto/sm4/sm4.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,18 @@ import "C"

import (
"bytes"
"crypto/cipher"
"fmt"
"unsafe"

"github.com/tongsuo-project/tongsuo-go-sdk/crypto"
)

const (
BlockSize = 16
KeySize = 16
)

type Encrypter interface {
// crypto.EncryptionCipherCtx
SetPadding(pad bool)
Expand Down Expand Up @@ -50,6 +57,50 @@ type sm4Decrypter struct {
tag []byte
}

type sm4Cipher struct {
rk [32]uint32
}

func (c *sm4Cipher) BlockSize() int {
return BlockSize
}

func NewCipher(key []byte) (cipher.Block, error) {
if len(key) != KeySize {
return nil, fmt.Errorf("invalid key size: %w", crypto.ErrInvalidKeySize)
}

cipher := &sm4Cipher{}
ret := C.SM4_set_key((*C.uchar)(&key[0]), (*C.SM4_KEY)(unsafe.Pointer(&cipher.rk)))
if ret != 1 {
return nil, fmt.Errorf("failed to set key: %w", crypto.ErrInternalError)
}

return cipher, nil
}

func (c *sm4Cipher) Encrypt(dst, src []byte) {
if len(src) < BlockSize {
panic("sm4: input not full block")
}
if len(dst) < BlockSize {
panic("sm4: output not full block")
}

C.SM4_encrypt((*C.uchar)(&src[0]), (*C.uchar)(&dst[0]), (*C.SM4_KEY)(unsafe.Pointer(&c.rk)))
}

func (c *sm4Cipher) Decrypt(dst, src []byte) {
if len(src) < BlockSize {
panic("sm4: input not full block")
}
if len(dst) < BlockSize {
panic("sm4: output not full block")
}

C.SM4_decrypt((*C.uchar)(&src[0]), (*C.uchar)(&dst[0]), (*C.SM4_KEY)(unsafe.Pointer(&c.rk)))
}

func getSM4Cipher(mode int) (*crypto.Cipher, error) {
var cipher *crypto.Cipher
var err error
Expand Down
174 changes: 174 additions & 0 deletions crypto/sm4/sm4_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package sm4_test

import (
"bytes"
"crypto/cipher"
"encoding/hex"
"strings"
"testing"
Expand All @@ -20,6 +21,179 @@ import (
const hexPlainText1 = `AAAAAAAAAAAAAAAABBBBBBBBBBBBBBBBCCCCCCCCCCCCCCCCDDDDDDDDDDDDDDDDEEEEEEEEEEEEEEEEFFFFFFFFFFFFFFFFE
EEEEEEEEEEEEEEEAAAAAAAAAAAAAAAA`

func TestSM4ECBWithCipherBlock(t *testing.T) {
t.Parallel()

key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
cipherText, _ := hex.DecodeString("681EDF34D206965E86B3E94F536E4246")

block, err := sm4.NewCipher(key)
if err != nil {
t.Fatal("failed to create SM4 cipher: ", err)
}

cipherText1 := make([]byte, len(plainText))

for i := 0; i < len(plainText); i += block.BlockSize() {
block.Encrypt(cipherText1[i:i+block.BlockSize()], plainText[i:i+block.BlockSize()])
}

if !bytes.Equal(cipherText1, cipherText) {
t.Fatalf("exp:%x got:%x", cipherText, cipherText1)
}

plainText1 := make([]byte, len(cipherText1))

for i := 0; i < len(cipherText1); i += block.BlockSize() {
block.Decrypt(plainText1[i:i+block.BlockSize()], cipherText1[i:i+block.BlockSize()])
}

if !bytes.Equal(plainText, plainText1) {
t.Fatalf("exp:%x got:%x", plainText, plainText1)
}
}

func testCryptWithCipherBlock(t *testing.T, mode string, key, iv, plainText, cipherText []byte) {
t.Helper()

block, err := sm4.NewCipher(key)
if err != nil {
t.Fatal("failed to create SM4 cipher: ", err)
}

cipherText1 := make([]byte, len(plainText))

switch mode {
case "CBC":
stream := cipher.NewCBCEncrypter(block, iv)
stream.CryptBlocks(cipherText1, plainText)
case "CFB":
stream := cipher.NewCFBEncrypter(block, iv)
stream.XORKeyStream(cipherText1, plainText)
case "OFB":
stream := cipher.NewOFB(block, iv)
stream.XORKeyStream(cipherText1, plainText)
case "CTR":
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(cipherText1, plainText)
}

if !bytes.Equal(cipherText1, cipherText) {
t.Fatalf("exp:%x got:%x", cipherText, cipherText1)
}

plainText1 := make([]byte, len(plainText))

switch mode {
case "CBC":
stream := cipher.NewCBCDecrypter(block, iv)
stream.CryptBlocks(plainText1, cipherText1)
case "CFB":
stream := cipher.NewCFBDecrypter(block, iv)
stream.XORKeyStream(plainText1, cipherText1)
case "OFB":
stream := cipher.NewOFB(block, iv)
stream.XORKeyStream(plainText1, cipherText1)
case "CTR":
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(plainText1, cipherText1)
}

if !bytes.Equal(plainText, plainText1) {
t.Fatalf("exp:%x got:%x", plainText, plainText1)
}
}

func TestSM4CBCWithCipherBlock(t *testing.T) {
t.Parallel()

key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210")
cipherText, _ := hex.DecodeString("2677F46B09C122CC975533105BD4A22AF6125F7275CE552C3A2BBCF533DE8A3B")

testCryptWithCipherBlock(t, "CBC", key, iv, plainText, cipherText)
}

func TestSM4CFBWithCipherBlock(t *testing.T) {
t.Parallel()

key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210")
cipherText, _ := hex.DecodeString("693D9A535BAD5BB1786F53D7253A70569ED258A85A0467CC92AAB393DD978995")

testCryptWithCipherBlock(t, "CFB", key, iv, plainText, cipherText)
}

func TestSM4OFBWithCipherBlock(t *testing.T) {
t.Parallel()

key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210")
cipherText, _ := hex.DecodeString("693D9A535BAD5BB1786F53D7253A7056F2075D28B5235F58D50027E4177D2BCE")

testCryptWithCipherBlock(t, "OFB", key, iv, plainText, cipherText)
}

func TestSM4CTRWithCipherBlock(t *testing.T) {
t.Parallel()

key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
hexCipherText := `C2B4759E78AC3CF43D0852F4E8D5F9FD7256E8A5FCB65A350EE00630912E44492A0B17E1B85B060D0FBA612D8A95831638
B361FD5FFACD942F081485A83CA35D`
plainText, _ := hex.DecodeString(strings.ReplaceAll(hexPlainText1, "\n", ""))
cipherText, _ := hex.DecodeString(strings.ReplaceAll(hexCipherText, "\n", ""))

testCryptWithCipherBlock(t, "CTR", key, iv, plainText, cipherText)
}

func TestSM4GCMWithCipherBlock(t *testing.T) {
t.Parallel()

key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
iv, _ := hex.DecodeString("00001234567800000000ABCD")
aad, _ := hex.DecodeString("FEEDFACEDEADBEEFFEEDFACEDEADBEEFABADDAD2")
tag, _ := hex.DecodeString("83DE3541E4C2B58177E065A9BF7B62EC")
hexCipherText := `17F399F08C67D5EE19D0DC9969C4BB7D5FD46FD3756489069157B282BB200735D82710CA5C22F0CCFA7CBF93D496AC15A5
6834CBCF98C397B4024A2691233B8D`
plainText, _ := hex.DecodeString(strings.ReplaceAll(hexPlainText1, "\n", ""))
cipherText, _ := hex.DecodeString(strings.ReplaceAll(hexCipherText, "\n", ""))

block, err := sm4.NewCipher(key)
if err != nil {
t.Fatal("failed to create SM4 cipher: ", err)
}

stream, err := cipher.NewGCM(block)
if err != nil {
t.Fatal("failed to create GCM: ", err)
}

cipherText1 := stream.Seal(nil, iv, plainText, aad)

if !bytes.Equal(cipherText1, append(cipherText, tag...)) {
t.Fatalf("exp:%x got:%x", cipherText1, append(cipherText, tag...))
}

stream2, err := cipher.NewGCM(block)
if err != nil {
t.Fatal("failed to create GCM: ", err)
}

plainText1, err := stream2.Open(nil, iv, cipherText1, aad)
if err != nil {
t.Fatal("failed to decrypt: ", err)
}

if !bytes.Equal(plainText1, plainText) {
t.Fatalf("exp:%x got:%x", plainText1, plainText)
}
}

func doEncrypt(t *testing.T, mode int, key, iv, plainText, cipherText []byte) {
t.Helper()

Expand Down
Loading
Loading