From 8f7d94a2f635e21e8aed22fe5d24528d07bf559c Mon Sep 17 00:00:00 2001 From: Eugene K Date: Fri, 14 Feb 2020 14:59:06 -0500 Subject: [PATCH] expose KeyPair.Public() key --- kx/kx.go | 28 +++++++++++++++++----------- kx/kx_test.go | 20 ++++++++++---------- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/kx/kx.go b/kx/kx.go index 4c02819..2001a1c 100644 --- a/kx/kx.go +++ b/kx/kx.go @@ -18,8 +18,8 @@ var cryptoError = errors.New("crypto error") var notImplemented = errors.New("not implemented") type KeyPair struct { - pk []byte - sk []byte + pk [SessionKeyBytes]byte + sk [PublicKeyBytes]byte } func NewKeyPair() (*KeyPair, error) { @@ -39,24 +39,30 @@ func newKeyPairFromSeed(seed []byte) (*KeyPair, error) { hash, _ := blake2b.New(SecretKeyBytes, nil) hash.Write(seed) - kp.sk = hash.Sum(nil) - - if len(kp.sk) != SecretKeyBytes { + sk := hash.Sum(nil) + if len(sk) != SecretKeyBytes { return nil, cryptoError } + copy(kp.sk[:], sk) - kp.pk, err = curve25519.X25519(kp.sk, curve25519.Basepoint) + pk, err := curve25519.X25519(kp.sk[:], curve25519.Basepoint) if err != nil { return nil, err } - if len(kp.pk) != PublicKeyBytes { + if len(pk) != PublicKeyBytes { return nil, cryptoError } + copy(kp.pk[:], pk) + return kp, nil } +func (pair *KeyPair) Public() []byte { + return pair.pk[:] +} + func (pair *KeyPair) ClientSessionKeys(server_pk []byte) (rx []byte, tx []byte, err error) { - q, err := curve25519.X25519(pair.sk, server_pk) + q, err := curve25519.X25519(pair.sk[:], server_pk) if err != nil { return nil, nil, err } @@ -66,7 +72,7 @@ func (pair *KeyPair) ClientSessionKeys(server_pk []byte) (rx []byte, tx []byte, return nil, nil, err } - for _, b := range [][]byte{q, pair.pk, server_pk} { + for _, b := range [][]byte{q, pair.Public(), server_pk} { if _, err = h.Write(b); err != nil { return nil, nil, err } @@ -80,7 +86,7 @@ func (pair *KeyPair) ClientSessionKeys(server_pk []byte) (rx []byte, tx []byte, func (pair *KeyPair) ServerSessionKeys(client_pk []byte) (rx []byte, tx []byte, err error) { - q, err := curve25519.X25519(pair.sk, client_pk) + q, err := curve25519.X25519(pair.sk[:], client_pk) if err != nil { return nil, nil, err } @@ -90,7 +96,7 @@ func (pair *KeyPair) ServerSessionKeys(client_pk []byte) (rx []byte, tx []byte, return nil, nil, err } - for _, b := range [][]byte{q, client_pk, pair.pk} { + for _, b := range [][]byte{q, client_pk, pair.Public()} { if _, err = h.Write(b); err != nil { return nil, nil, err } diff --git a/kx/kx_test.go b/kx/kx_test.go index cbe4103..bcb1085 100644 --- a/kx/kx_test.go +++ b/kx/kx_test.go @@ -33,6 +33,9 @@ func seedIncrement(s []byte) []byte { func TestNewKeyPair(t *testing.T) { pk, _ := hex.DecodeString("0e0216223f147143d32615a91189c288c1728cba3cc5f9f621b1026e03d83129") sk, _ := hex.DecodeString("cb2f5160fc1f7e05a55ef49d340b48da2e5a78099d53393351cd579dd42503d6") + kp := &KeyPair{} + copy(kp.pk[:], pk) + copy(kp.sk[:], sk) type args struct { seed []byte @@ -45,12 +48,9 @@ func TestNewKeyPair(t *testing.T) { wantErr bool }{ { - name: "pre-seeded key", - args: args{seed: seed}, - want: &KeyPair{ - pk: pk, - sk: sk, - }, + name: "pre-seeded key", + args: args{seed: seed}, + want: kp, wantErr: false, }, } @@ -84,7 +84,7 @@ func TestKeyExchange_Seeded(t *testing.T) { clt_rx, _ := hex.DecodeString("749519c68059bce69f7cfcc7b387a3de1a1e8237d110991323bf62870115731a") clt_tx, _ := hex.DecodeString("62c8f4fa81800abd0577d99918d129b65deb789af8c8351f391feb0cbf238604") - client_rx, client_tx, err := client_pair.ClientSessionKeys(server_pair.pk) + client_rx, client_tx, err := client_pair.ClientSessionKeys(server_pair.Public()) if err != nil { t.Errorf("ClientSessionKeys: error = %v", err) return @@ -97,7 +97,7 @@ func TestKeyExchange_Seeded(t *testing.T) { t.Errorf("ClientSessionKeys(): TX got = %v, want %v", client_tx, clt_tx) } - server_rx, server_tx, err := server_pair.ServerSessionKeys(client_pair.pk) + server_rx, server_tx, err := server_pair.ServerSessionKeys(client_pair.Public()) if err != nil { t.Errorf("ServerSessionKeys: error = %v", err) return @@ -124,13 +124,13 @@ func TestKeyExchange(t *testing.T) { return } - client_rx, client_tx, err := client_pair.ClientSessionKeys(server_pair.pk) + client_rx, client_tx, err := client_pair.ClientSessionKeys(server_pair.Public()) if err != nil { t.Errorf("ClientSessionKeys: error = %v", err) return } - server_rx, server_tx, err := server_pair.ServerSessionKeys(client_pair.pk) + server_rx, server_tx, err := server_pair.ServerSessionKeys(client_pair.Public()) if err != nil { t.Errorf("ServerSessionKeys: error = %v", err) return