diff --git a/key/key_test.go b/key/key_test.go index d2b3104..e84d99b 100644 --- a/key/key_test.go +++ b/key/key_test.go @@ -27,9 +27,19 @@ func TestPublicKey(t *testing.T) { require.Equal(t, "", hex.EncodeToString(b.Bytes())) }) - t.Run("invalid key", func(t *testing.T) { - _, err := NewPublicKey("123") - require.EqualError(t, err, "wrong prefix") + t.Run("wrong key len", func(t *testing.T) { + _, err := NewPublicKey("") + require.ErrorIs(t, err, ErrWrongKeyLen) + }) + + t.Run("wrong prefix", func(t *testing.T) { + _, err := NewPublicKey("SSS11111111111111111111111111111111111111111111111111") + require.ErrorIs(t, err, ErrWrongPrefix) + }) + + t.Run("wrong check sum", func(t *testing.T) { + _, err := NewPublicKey("SCR11111111111111111111111111111111111111111111111111") + require.ErrorIs(t, err, ErrWrongChecksum) }) t.Run("public key from string", func(t *testing.T) { @@ -89,15 +99,13 @@ func TestSignAndValidate(t *testing.T) { sig := pk.Sign(hash) require.Equal(t, sigHex, hex.EncodeToString(sig)) - ok, err := pk.PublicKey().Verify(hash, sig) + err = pk.PublicKey().Verify(hash, sig) require.NoError(t, err) - require.True(t, ok) - t.Run("validate with wrong key return false", func(t *testing.T) { + t.Run("validate with wrong key key mismatch", func(t *testing.T) { wrogKey, err := NewPublicKey("SCR5jPZF7PMgTpLqkdfpMu8kXea8Gio6E646aYpTgcjr9qMLrAgnL") require.NoError(t, err) - ok, err := wrogKey.Verify(hash, sig) - require.NoError(t, err) - require.False(t, ok) + err = wrogKey.Verify(hash, sig) + require.ErrorIs(t, err, ErrKeyMismatch) }) } diff --git a/key/publickey.go b/key/publickey.go index a70bafa..de14858 100644 --- a/key/publickey.go +++ b/key/publickey.go @@ -3,6 +3,7 @@ package key import ( "bytes" "errors" + "fmt" "strings" "github.com/scorum/scorum-go/encoding/transaction" @@ -15,11 +16,14 @@ import ( const ( publicKeyPrefix = "SCR" checkSumLen = 4 + keyLen = 53 ) var ( ErrWrongPrefix = errors.New("wrong prefix") ErrWrongChecksum = errors.New("wrong check sum") + ErrWrongKeyLen = errors.New("wrong key len") + ErrKeyMismatch = errors.New("key mismatch") ) type PublicKey struct { @@ -27,10 +31,16 @@ type PublicKey struct { } func NewPublicKey(pubKey string) (*PublicKey, error) { + if len(pubKey) != keyLen { + return nil, ErrWrongKeyLen + } + if !strings.HasPrefix(pubKey, publicKeyPrefix) { return nil, ErrWrongPrefix } + keyWithChecksum := base58.Decode(pubKey[len(publicKeyPrefix):]) + key := keyWithChecksum[:len(keyWithChecksum)-checkSumLen] h := ripemd160.New() @@ -69,13 +79,17 @@ func (p *PublicKey) Serialize() []byte { return p.raw.SerializeCompressed() } -func (p *PublicKey) Verify(digest []byte, signature []byte) (bool, error) { +func (p *PublicKey) Verify(digest []byte, signature []byte) error { pub, _, err := btcec.RecoverCompact(btcec.S256(), signature, digest) if err != nil { - return false, err + return fmt.Errorf("recover compact: %w", err) + } + + if !p.raw.IsEqual(pub) { + return ErrKeyMismatch } - return p.raw.IsEqual(pub), nil + return nil } func (p *PublicKey) MarshalTransaction(encoder *transaction.Encoder) error { diff --git a/sign/signed_transaction.go b/sign/signed_transaction.go index 57e3528..93a2d65 100644 --- a/sign/signed_transaction.go +++ b/sign/signed_transaction.go @@ -67,24 +67,24 @@ func (tx *SignedTransaction) Sign(chainID []byte, keys ...*key.PrivateKey) error return nil } -func (tx *SignedTransaction) Verify(chainID []byte, keys ...*key.PublicKey) (bool, error) { +func (tx *SignedTransaction) Verify(chainID []byte, keys ...*key.PublicKey) error { dig, err := tx.Digest(chainID) if err != nil { - return false, fmt.Errorf("failed to get digest: %w", err) + return fmt.Errorf("failed to get digest: %w", err) } for _, signature := range tx.Signatures { sig, err := hex.DecodeString(signature) if err != nil { - return false, fmt.Errorf("failed to decode signature: %w", err) + return fmt.Errorf("failed to decode signature: %w", err) } for _, k := range keys { - if ok, err := k.Verify(dig, sig); err != nil || !ok { - return ok, fmt.Errorf("verify signature: %w", err) + if err := k.Verify(dig, sig); err != nil { + return fmt.Errorf("verify signature: %w", err) } } } - return true, nil + return nil } diff --git a/sign/signed_transaction_test.go b/sign/signed_transaction_test.go index 9647f60..b07b479 100644 --- a/sign/signed_transaction_test.go +++ b/sign/signed_transaction_test.go @@ -72,7 +72,5 @@ func TestTransaction_Verify(t *testing.T) { pubKey, err := key.NewPublicKey("SCR7cTf2Dx9rxffs6E2z2pdn5cLMneo3AAFSsF9g4SaVviCYdfQ63") require.NoError(t, err) - res, err := stx.Verify(TestNetChainID, pubKey) - require.NoError(t, err) - require.True(t, res) + require.NoError(t, stx.Verify(TestNetChainID, pubKey)) } diff --git a/types/key_test.go b/types/key_test.go index 20bbc5d..d82f453 100644 --- a/types/key_test.go +++ b/types/key_test.go @@ -11,27 +11,12 @@ import ( ) func TestPublicKey_MarshalTransaction(t *testing.T) { - t.Run("empty key", func(t *testing.T) { - var key PublicKey - var b bytes.Buffer - encoder := transaction.NewEncoder(&b) - require.EqualError(t, key.MarshalTransaction(encoder), "wrong prefix") - require.Equal(t, "", hex.EncodeToString(b.Bytes())) - }) + var ( + b bytes.Buffer + key PublicKey = "SCR5jPZF7PMgTpLqkdfpMu8kXea8Gio6E646aYpTgcjr9qMLrAgnL" + encoder = transaction.NewEncoder(&b) + ) - t.Run("invalid key", func(t *testing.T) { - key := PublicKey("123") - var b bytes.Buffer - encoder := transaction.NewEncoder(&b) - require.EqualError(t, key.MarshalTransaction(encoder), "wrong prefix") - require.Equal(t, "", hex.EncodeToString(b.Bytes())) - }) - - t.Run("valid key", func(t *testing.T) { - key := PublicKey("SCR5jPZF7PMgTpLqkdfpMu8kXea8Gio6E646aYpTgcjr9qMLrAgnL") - var b bytes.Buffer - encoder := transaction.NewEncoder(&b) - require.NoError(t, key.MarshalTransaction(encoder)) - require.Equal(t, "026f0896f24d94252c351715bfe6052bbf9ea820e805bd47c2496c626d3467da5d", hex.EncodeToString(b.Bytes())) - }) + require.NoError(t, key.MarshalTransaction(encoder)) + require.Equal(t, "026f0896f24d94252c351715bfe6052bbf9ea820e805bd47c2496c626d3467da5d", hex.EncodeToString(b.Bytes())) }