Skip to content

Commit

Permalink
simplify PublicKey.Verify, handle errors on NewPublicKey
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander Plevako committed Apr 28, 2022
1 parent e61a307 commit 3625018
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 43 deletions.
26 changes: 17 additions & 9 deletions key/key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,19 @@ func TestPublicKey(t *testing.T) {
require.Equal(t, "", hex.EncodeToString(b.Bytes()))
})

t.Run("invalid key", func(t *testing.T) {
_, err := NewPublicKey("123")
require.EqualError(t, err, "wrong prefix")
t.Run("wrong key len", func(t *testing.T) {
_, err := NewPublicKey("")
require.ErrorIs(t, err, ErrWrongKeyLen)
})

t.Run("wrong prefix", func(t *testing.T) {
_, err := NewPublicKey("SSS11111111111111111111111111111111111111111111111111")
require.ErrorIs(t, err, ErrWrongPrefix)
})

t.Run("wrong check sum", func(t *testing.T) {
_, err := NewPublicKey("SCR11111111111111111111111111111111111111111111111111")
require.ErrorIs(t, err, ErrWrongChecksum)
})

t.Run("public key from string", func(t *testing.T) {
Expand Down Expand Up @@ -89,15 +99,13 @@ func TestSignAndValidate(t *testing.T) {
sig := pk.Sign(hash)
require.Equal(t, sigHex, hex.EncodeToString(sig))

ok, err := pk.PublicKey().Verify(hash, sig)
err = pk.PublicKey().Verify(hash, sig)
require.NoError(t, err)
require.True(t, ok)

t.Run("validate with wrong key return false", func(t *testing.T) {
t.Run("validate with wrong key key mismatch", func(t *testing.T) {
wrogKey, err := NewPublicKey("SCR5jPZF7PMgTpLqkdfpMu8kXea8Gio6E646aYpTgcjr9qMLrAgnL")
require.NoError(t, err)
ok, err := wrogKey.Verify(hash, sig)
require.NoError(t, err)
require.False(t, ok)
err = wrogKey.Verify(hash, sig)
require.ErrorIs(t, err, ErrKeyMismatch)
})
}
20 changes: 17 additions & 3 deletions key/publickey.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package key
import (
"bytes"
"errors"
"fmt"
"strings"

"github.com/scorum/scorum-go/encoding/transaction"
Expand All @@ -15,22 +16,31 @@ import (
const (
publicKeyPrefix = "SCR"
checkSumLen = 4
keyLen = 53
)

var (
ErrWrongPrefix = errors.New("wrong prefix")
ErrWrongChecksum = errors.New("wrong check sum")
ErrWrongKeyLen = errors.New("wrong key len")
ErrKeyMismatch = errors.New("key mismatch")
)

type PublicKey struct {
raw *btcec.PublicKey
}

func NewPublicKey(pubKey string) (*PublicKey, error) {
if len(pubKey) != keyLen {
return nil, ErrWrongKeyLen
}

if !strings.HasPrefix(pubKey, publicKeyPrefix) {
return nil, ErrWrongPrefix
}

keyWithChecksum := base58.Decode(pubKey[len(publicKeyPrefix):])

key := keyWithChecksum[:len(keyWithChecksum)-checkSumLen]

h := ripemd160.New()
Expand Down Expand Up @@ -69,13 +79,17 @@ func (p *PublicKey) Serialize() []byte {
return p.raw.SerializeCompressed()
}

func (p *PublicKey) Verify(digest []byte, signature []byte) (bool, error) {
func (p *PublicKey) Verify(digest []byte, signature []byte) error {
pub, _, err := btcec.RecoverCompact(btcec.S256(), signature, digest)
if err != nil {
return false, err
return fmt.Errorf("recover compact: %w", err)
}

if !p.raw.IsEqual(pub) {
return ErrKeyMismatch
}

return p.raw.IsEqual(pub), nil
return nil
}

func (p *PublicKey) MarshalTransaction(encoder *transaction.Encoder) error {
Expand Down
12 changes: 6 additions & 6 deletions sign/signed_transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,24 +67,24 @@ func (tx *SignedTransaction) Sign(chainID []byte, keys ...*key.PrivateKey) error
return nil
}

func (tx *SignedTransaction) Verify(chainID []byte, keys ...*key.PublicKey) (bool, error) {
func (tx *SignedTransaction) Verify(chainID []byte, keys ...*key.PublicKey) error {
dig, err := tx.Digest(chainID)
if err != nil {
return false, fmt.Errorf("failed to get digest: %w", err)
return fmt.Errorf("failed to get digest: %w", err)
}

for _, signature := range tx.Signatures {
sig, err := hex.DecodeString(signature)
if err != nil {
return false, fmt.Errorf("failed to decode signature: %w", err)
return fmt.Errorf("failed to decode signature: %w", err)
}

for _, k := range keys {
if ok, err := k.Verify(dig, sig); err != nil || !ok {
return ok, fmt.Errorf("verify signature: %w", err)
if err := k.Verify(dig, sig); err != nil {
return fmt.Errorf("verify signature: %w", err)
}
}
}

return true, nil
return nil
}
4 changes: 1 addition & 3 deletions sign/signed_transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,5 @@ func TestTransaction_Verify(t *testing.T) {
pubKey, err := key.NewPublicKey("SCR7cTf2Dx9rxffs6E2z2pdn5cLMneo3AAFSsF9g4SaVviCYdfQ63")
require.NoError(t, err)

res, err := stx.Verify(TestNetChainID, pubKey)
require.NoError(t, err)
require.True(t, res)
require.NoError(t, stx.Verify(TestNetChainID, pubKey))
}
29 changes: 7 additions & 22 deletions types/key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,12 @@ import (
)

func TestPublicKey_MarshalTransaction(t *testing.T) {
t.Run("empty key", func(t *testing.T) {
var key PublicKey
var b bytes.Buffer
encoder := transaction.NewEncoder(&b)
require.EqualError(t, key.MarshalTransaction(encoder), "wrong prefix")
require.Equal(t, "", hex.EncodeToString(b.Bytes()))
})
var (
b bytes.Buffer
key PublicKey = "SCR5jPZF7PMgTpLqkdfpMu8kXea8Gio6E646aYpTgcjr9qMLrAgnL"
encoder = transaction.NewEncoder(&b)
)

t.Run("invalid key", func(t *testing.T) {
key := PublicKey("123")
var b bytes.Buffer
encoder := transaction.NewEncoder(&b)
require.EqualError(t, key.MarshalTransaction(encoder), "wrong prefix")
require.Equal(t, "", hex.EncodeToString(b.Bytes()))
})

t.Run("valid key", func(t *testing.T) {
key := PublicKey("SCR5jPZF7PMgTpLqkdfpMu8kXea8Gio6E646aYpTgcjr9qMLrAgnL")
var b bytes.Buffer
encoder := transaction.NewEncoder(&b)
require.NoError(t, key.MarshalTransaction(encoder))
require.Equal(t, "026f0896f24d94252c351715bfe6052bbf9ea820e805bd47c2496c626d3467da5d", hex.EncodeToString(b.Bytes()))
})
require.NoError(t, key.MarshalTransaction(encoder))
require.Equal(t, "026f0896f24d94252c351715bfe6052bbf9ea820e805bd47c2496c626d3467da5d", hex.EncodeToString(b.Bytes()))
}

0 comments on commit 3625018

Please sign in to comment.