From 284cebd8a9a365af18ae649f841681c472ffd75d Mon Sep 17 00:00:00 2001 From: Minh Huy Tran Date: Fri, 16 Feb 2024 11:47:33 +0100 Subject: [PATCH] feature(wire/net/simple & backend/sim/wire): change wire implementation to fit new interface + add tests for authentification Signed-off-by: Minh Huy Tran --- backend/sim/wire/account.go | 5 + backend/sim/wire/address.go | 9 ++ go.mod | 3 +- go.sum | 9 +- wire/net/simple/account.go | 39 ++++-- wire/net/simple/address.go | 116 +++++++++++++++--- wire/net/simple/dialer_internal_test.go | 108 +++++++++++----- wire/net/simple/mockconn_internal_test.go | 68 ++++++++++ .../simple_exchange_addr_internal_test.go | 107 ++++++++++++++++ 9 files changed, 400 insertions(+), 64 deletions(-) create mode 100644 wire/net/simple/mockconn_internal_test.go create mode 100644 wire/net/simple/simple_exchange_addr_internal_test.go diff --git a/backend/sim/wire/account.go b/backend/sim/wire/account.go index ec5facba9..191cbcd17 100644 --- a/backend/sim/wire/account.go +++ b/backend/sim/wire/account.go @@ -30,6 +30,11 @@ func (acc *Account) Address() wire.Address { return acc.addr } +// Sign signs the given message with the account's private key. +func (acc *Account) Sign(msg []byte) ([]byte, error) { + return []byte("Authenticate"), nil +} + // NewRandomAccount generates a new random account. func NewRandomAccount(rng *rand.Rand) *Account { return &Account{ diff --git a/backend/sim/wire/address.go b/backend/sim/wire/address.go index 5db6d7d2d..4c241dc5c 100644 --- a/backend/sim/wire/address.go +++ b/backend/sim/wire/address.go @@ -16,6 +16,7 @@ package wire import ( "bytes" + "errors" "math/rand" "perun.network/go-perun/wire" @@ -62,6 +63,14 @@ func (a Address) Cmp(b wire.Address) int { return bytes.Compare(a[:], bTyped[:]) } +// Verify verifies a signature. +func (a Address) Verify(msg, sig []byte) error { + if !bytes.Equal(sig, []byte("Authenticate")) { + return errors.New("invalid signature") + } + return nil +} + // NewRandomAddress returns a new random peer address. func NewRandomAddress(rng *rand.Rand) *Address { addr := Address{} diff --git a/go.mod b/go.mod index 5d66b4575..bb771df5c 100644 --- a/go.mod +++ b/go.mod @@ -9,14 +9,13 @@ require ( go.uber.org/goleak v1.1.11 golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c - google.golang.org/protobuf v1.23.0 + google.golang.org/protobuf v1.32.0 polycry.pt/poly-go v0.0.0-20220222131629-aa4bdbaab60b ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/golang/snappy v0.0.4 // indirect - github.com/google/go-cmp v0.5.4 // indirect github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.1.1 // indirect diff --git a/go.sum b/go.sum index 3e51d0b9f..353017516 100644 --- a/go.sum +++ b/go.sum @@ -12,13 +12,14 @@ github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrU github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.4 h1:L8R9j+yAqZuZjsqh/z+F1NCffTKKLShY6zXTItVIZ8M= -github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= @@ -114,8 +115,10 @@ google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= +google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/wire/net/simple/account.go b/wire/net/simple/account.go index b1999cf10..64f9f642d 100644 --- a/wire/net/simple/account.go +++ b/wire/net/simple/account.go @@ -15,21 +15,20 @@ package simple import ( + "crypto" + crypto_rand "crypto/rand" + "crypto/rsa" + "crypto/sha256" "math/rand" + "github.com/pkg/errors" "perun.network/go-perun/wire" ) // Account is a wire account. type Account struct { - addr wire.Address -} - -// NewAccount creates a new account. -func NewAccount(addr *Address) *Account { - return &Account{ - addr: addr, - } + addr wire.Address + privateKey *rsa.PrivateKey } // Address returns the account's address. @@ -37,9 +36,31 @@ func (acc *Account) Address() wire.Address { return acc.addr } +// Sign signs the given message with the account's private key. +func (acc *Account) Sign(msg []byte) ([]byte, error) { + if acc.privateKey == nil { + return nil, errors.New("private key is nil") + } + hashed := sha256.Sum256([]byte(msg)) + signature, err := rsa.SignPKCS1v15(crypto_rand.Reader, acc.privateKey, crypto.SHA256, hashed[:]) + if err != nil { + return nil, err + } + return signature, nil +} + // NewRandomAccount generates a new random account. func NewRandomAccount(rng *rand.Rand) *Account { + privateKey, err := rsa.GenerateKey(rng, 2048) + if err != nil { + panic(err) + } + + address := NewRandomAddress(rng) + address.PublicKey = &privateKey.PublicKey + return &Account{ - addr: NewRandomAddress(rng), + addr: address, + privateKey: privateKey, } } diff --git a/wire/net/simple/address.go b/wire/net/simple/address.go index 3cc962df2..6706a86fe 100644 --- a/wire/net/simple/address.go +++ b/wire/net/simple/address.go @@ -16,54 +16,120 @@ package simple import ( "bytes" + "crypto" + "crypto/rsa" + "crypto/sha256" + "encoding/binary" + "encoding/gob" "math/rand" "perun.network/go-perun/wire" ) // Address is a wire address. -type Address string +type Address struct { + Name string + PublicKey *rsa.PublicKey // Public key for verifying signatures -var _ wire.Address = NewAddress("") +} + +var _ wire.Address = (*Address)(nil) // NewAddress returns a new address. func NewAddress(host string) *Address { - a := Address(host) - return &a + return &Address{ + Name: host, + PublicKey: nil, + } } // MarshalBinary marshals the address to binary. -func (a Address) MarshalBinary() ([]byte, error) { - buf := make([]byte, len(a)) - copy(buf, []byte(a)) - return buf, nil +func (a *Address) MarshalBinary() ([]byte, error) { + // Initialize a buffer to hold the binary data + var buf bytes.Buffer + + // Encode the length of the name string and the name itself + nameLen := uint16(len(a.Name)) + if err := binary.Write(&buf, binary.BigEndian, nameLen); err != nil { + return nil, err + } + if _, err := buf.WriteString(a.Name); err != nil { + return nil, err + } + + // If the public key is not nil, encode it using gob + if a.PublicKey != nil { + enc := gob.NewEncoder(&buf) + if err := enc.Encode(a.PublicKey); err != nil { + return nil, err + } + } + + // Return the binary representation + return buf.Bytes(), nil } // UnmarshalBinary unmarshals an address from binary. func (a *Address) UnmarshalBinary(data []byte) error { - buf := make([]byte, len(data)) - copy(buf, data) - *a = Address(buf) + // Initialize a buffer with the binary data + buf := bytes.NewReader(data) + + // Decode the length of the name string + var nameLen uint16 + if err := binary.Read(buf, binary.BigEndian, &nameLen); err != nil { + return err + } + + // Read the name string from the buffer + nameBytes := make([]byte, nameLen) + if _, err := buf.Read(nameBytes); err != nil { + return err + } + a.Name = string(nameBytes) + + // If there's remaining data, decode the public key using gob + if buf.Len() > 0 { + dec := gob.NewDecoder(buf) + if err := dec.Decode(&a.PublicKey); err != nil { + return err + } + } + return nil } // Equal returns whether the two addresses are equal. -func (a Address) Equal(b wire.Address) bool { +func (a *Address) Equal(b wire.Address) bool { bTyped, ok := b.(*Address) if !ok { return false } - return a == *bTyped + if a.PublicKey == nil { + return a.Name == bTyped.Name && bTyped.PublicKey == nil + } + + return a.Name == bTyped.Name && a.PublicKey.Equal(bTyped.PublicKey) } -// Cmp compares the byte representation of two addresses. For `a.Cmp(b)` -// returns -1 if a < b, 0 if a == b, 1 if a > b. -func (a Address) Cmp(b wire.Address) int { +// Cmp compares the byte representation of two addresses. +func (a *Address) Cmp(b wire.Address) int { bTyped, ok := b.(*Address) if !ok { panic("wrong type") } - return bytes.Compare([]byte(a), []byte(*bTyped)) + if cmp := bytes.Compare([]byte(a.Name), []byte(bTyped.Name)); cmp != 0 { + return cmp + } + + bytesA, err := a.MarshalBinary() + if err != nil { + panic(err) + } + bytesB, err := bTyped.MarshalBinary() + if err != nil { + panic(err) + } + return bytes.Compare(bytesA, bytesB) } // NewRandomAddress returns a new random peer address. @@ -75,6 +141,18 @@ func NewRandomAddress(rng *rand.Rand) *Address { panic(err) } - a := Address(d) - return &a + a := &Address{ + Name: string(d), + } + return a +} + +// Verify verifies a message signature. +func (a *Address) Verify(msg []byte, sig []byte) error { + hashed := sha256.Sum256(msg) + err := rsa.VerifyPKCS1v15(a.PublicKey, crypto.SHA256, hashed[:], sig) + if err != nil { + return err + } + return nil } diff --git a/wire/net/simple/dialer_internal_test.go b/wire/net/simple/dialer_internal_test.go index 7e5c8d752..5d6f360af 100644 --- a/wire/net/simple/dialer_internal_test.go +++ b/wire/net/simple/dialer_internal_test.go @@ -33,7 +33,6 @@ import ( "perun.network/go-perun/wire" perunio "perun.network/go-perun/wire/perunio/serializer" - wiretest "perun.network/go-perun/wire/test" ctxtest "polycry.pt/poly-go/context/test" "polycry.pt/poly-go/test" ) @@ -53,7 +52,7 @@ func TestNewUnixDialer(t *testing.T) { func TestDialer_Register(t *testing.T) { tlsConfig := &tls.Config{} rng := test.Prng(t) - addr := wiretest.NewRandomAddress(rng) + addr := NewRandomAddress(rng) key := wire.Key(addr) d := NewTCPDialer(0, tlsConfig) @@ -71,7 +70,7 @@ func TestDialer_Dial(t *testing.T) { timeout := 100 * time.Millisecond rng := test.Prng(t) lhost := "127.0.0.1:7357" - laddr := wiretest.NewRandomAddress(rng) + laddr := NewRandomAccount(rng).Address() commonName := "127.0.0.1" sans := []string{"127.0.0.1", "localhost"} @@ -85,7 +84,7 @@ func TestDialer_Dial(t *testing.T) { ser := perunio.Serializer() d := NewTCPDialer(timeout, dConfig) d.Register(laddr, lhost) - daddr := wiretest.NewRandomAddress(rng) + daddr := NewRandomAccount(rng).Address() defer d.Close() t.Run("happy", func(t *testing.T) { @@ -129,7 +128,7 @@ func TestDialer_Dial(t *testing.T) { }) t.Run("unknown host", func(t *testing.T) { - noHostAddr := wiretest.NewRandomAddress(rng) + noHostAddr := NewRandomAddress(rng) d.Register(noHostAddr, "no such host") ctxtest.AssertTerminates(t, timeout, func() { @@ -141,7 +140,7 @@ func TestDialer_Dial(t *testing.T) { t.Run("unknown address", func(t *testing.T) { ctxtest.AssertTerminates(t, timeout, func() { - unkownAddr := wiretest.NewRandomAddress(rng) + unkownAddr := NewRandomAddress(rng) conn, err := d.Dial(context.Background(), unkownAddr, ser) assert.Error(t, err) assert.Nil(t, conn) @@ -152,14 +151,20 @@ func TestDialer_Dial(t *testing.T) { // generateSelfSignedCertConfigs generates a self-signed certificate and returns // the server and client TLS configurations. func generateSelfSignedCertConfigs(commonName string, sans []string) (*tls.Config, *tls.Config, error) { - // Generate a new RSA private key - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + // Generate a new RSA private key for the server + serverPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return nil, nil, err } - // Create a certificate template - template := x509.Certificate{ + // Generate a new RSA private key for the client + clientPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, err + } + + // Create a certificate template for the server + serverTemplate := x509.Certificate{ SerialNumber: big.NewInt(1), Subject: pkix.Name{ Organization: []string{"Perun Network"}, @@ -172,48 +177,89 @@ func generateSelfSignedCertConfigs(commonName string, sans []string) (*tls.Confi BasicConstraintsValid: true, } - // Add SANs to the certificate template + // Add SANs to the server certificate template for _, san := range sans { if ip := net.ParseIP(san); ip != nil { - template.IPAddresses = append(template.IPAddresses, ip) + serverTemplate.IPAddresses = append(serverTemplate.IPAddresses, ip) } else { - template.DNSNames = append(template.DNSNames, san) + serverTemplate.DNSNames = append(serverTemplate.DNSNames, san) } } - // Generate a self-signed certificate - certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + // Generate a self-signed server certificate + serverCertDER, err := x509.CreateCertificate(rand.Reader, &serverTemplate, &serverTemplate, &serverPrivateKey.PublicKey, serverPrivateKey) if err != nil { return nil, nil, err } - // Encode the certificate to PEM format - certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + // Encode the server certificate to PEM format + serverCertPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: serverCertDER}) - // Encode the private key to PEM format - keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) + // Encode the server private key to PEM format + serverKeyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(serverPrivateKey)}) - // Create a tls.Certificate object - cert, err := tls.X509KeyPair(certPEM, keyPEM) + // Create a tls.Certificate object for the server + serverCert, err := tls.X509KeyPair(serverCertPEM, serverKeyPEM) if err != nil { return nil, nil, err } - // Server-side TLS configuration - lConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, + // Create a certificate template for the client + clientTemplate := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Perun Network"}, + CommonName: commonName, // Change this to the client's common name + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, // Set the client authentication usage + BasicConstraintsValid: true, } - // Client-side TLS configuration - roots := x509.NewCertPool() - ok := roots.AppendCertsFromPEM(certPEM) + // Generate a self-signed client certificate + clientCertDER, err := x509.CreateCertificate(rand.Reader, &clientTemplate, &clientTemplate, &clientPrivateKey.PublicKey, serverPrivateKey) + if err != nil { + return nil, nil, err + } + + // Encode the client certificate to PEM format + clientCertPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: clientCertDER}) + + // Encode the client private key to PEM format + clientKeyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(clientPrivateKey)}) + + // Create a tls.Certificate object for the client + clientCert, err := tls.X509KeyPair(clientCertPEM, clientKeyPEM) + if err != nil { + return nil, nil, err + } + + serverCertPool := x509.NewCertPool() + ok := serverCertPool.AppendCertsFromPEM(clientCertPEM) if !ok { return nil, nil, fmt.Errorf("failed to parse root certificate") } - dConfig := &tls.Config{ - RootCAs: roots, - ServerName: commonName, // Use the commonName as the ServerName for the client + + // Create the server-side TLS configuration + serverConfig := &tls.Config{ + ClientCAs: serverCertPool, + Certificates: []tls.Certificate{serverCert}, + ClientAuth: tls.RequireAndVerifyClientCert, + } + + clientCertPool := x509.NewCertPool() + ok = clientCertPool.AppendCertsFromPEM(serverCertPEM) + if !ok { + return nil, nil, fmt.Errorf("failed to parse root certificate") + } + + // Create the client-side TLS configuration + clientConfig := &tls.Config{ + RootCAs: clientCertPool, + Certificates: []tls.Certificate{clientCert}, } - return lConfig, dConfig, nil + return serverConfig, clientConfig, nil } diff --git a/wire/net/simple/mockconn_internal_test.go b/wire/net/simple/mockconn_internal_test.go new file mode 100644 index 000000000..73350efcb --- /dev/null +++ b/wire/net/simple/mockconn_internal_test.go @@ -0,0 +1,68 @@ +// Copyright 2019 - See NOTICE file for copyright holders. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package simple + +import ( + "sync" + + "github.com/pkg/errors" + + "perun.network/go-perun/wire" + wirenet "perun.network/go-perun/wire/net" + "polycry.pt/poly-go/sync/atomic" +) + +var _ wirenet.Conn = (*MockConn)(nil) + +type MockConn struct { + mutex sync.Mutex + closed atomic.Bool + recvQueue chan *wire.Envelope + + sent func(*wire.Envelope) // observes sent messages. +} + +func newMockConn() *MockConn { + return &MockConn{ + sent: func(*wire.Envelope) {}, + recvQueue: make(chan *wire.Envelope, 1), + } +} + +func (c *MockConn) Send(e *wire.Envelope) error { + c.mutex.Lock() + defer c.mutex.Unlock() + if c.closed.IsSet() { + return errors.New("closed") + } + c.sent(e) + return nil +} + +func (c *MockConn) Recv() (*wire.Envelope, error) { + c.mutex.Lock() + defer c.mutex.Unlock() + if c.closed.IsSet() { + return nil, errors.New("closed") + } + return <-c.recvQueue, nil +} + +func (c *MockConn) Close() error { + if !c.closed.TrySet() { + return errors.New("double close") + } + return nil +} diff --git a/wire/net/simple/simple_exchange_addr_internal_test.go b/wire/net/simple/simple_exchange_addr_internal_test.go new file mode 100644 index 000000000..051e27ede --- /dev/null +++ b/wire/net/simple/simple_exchange_addr_internal_test.go @@ -0,0 +1,107 @@ +// Copyright 2020 - See NOTICE file for copyright holders. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package simple + +import ( + "context" + "math/rand" + "net" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "perun.network/go-perun/wire" + wirenet "perun.network/go-perun/wire/net" + perunio "perun.network/go-perun/wire/perunio/serializer" + ctxtest "polycry.pt/poly-go/context/test" + "polycry.pt/poly-go/test" +) + +const timeout = 100 * time.Millisecond + +func TestExchangeAddrs_ConnFail(t *testing.T) { + rng := test.Prng(t) + a, _ := newPipeConnPair() + a.Close() + addr, err := wirenet.ExchangeAddrsPassive(context.Background(), NewRandomAccount(rng), a) + assert.Nil(t, addr) + assert.Error(t, err) +} + +func TestExchangeAddrs_Success(t *testing.T) { + rng := test.Prng(t) + conn0, conn1 := newPipeConnPair() + defer conn0.Close() + account0, account1 := NewRandomAccount(rng), NewRandomAccount(rng) + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + defer conn1.Close() + + recvAddr0, err := wirenet.ExchangeAddrsPassive(context.Background(), account1, conn1) + assert.NoError(t, err) + assert.True(t, recvAddr0.Equal(account0.Address())) + }() + + err := wirenet.ExchangeAddrsActive(context.Background(), account0, account1.Address(), conn0) + assert.NoError(t, err) + + wg.Wait() +} + +func TestExchangeAddrs_Timeout(t *testing.T) { + rng := test.Prng(t) + a, _ := newPipeConnPair() + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + ctxtest.AssertTerminates(t, 2*timeout, func() { + addr, err := wirenet.ExchangeAddrsPassive(ctx, NewRandomAccount(rng), a) + assert.Nil(t, addr) + assert.Error(t, err) + }) +} + +func TestExchangeAddrs_BogusMsg(t *testing.T) { + rng := test.Prng(t) + acc := NewRandomAccount(rng) + conn := newMockConn() + conn.recvQueue <- newRandomEnvelope(rng, wire.NewPingMsg()) + addr, err := wirenet.ExchangeAddrsPassive(context.Background(), acc, conn) + + assert.Error(t, err, "ExchangeAddrs should error when peer sends a non-AuthResponseMsg") + assert.Nil(t, addr) +} + +// newPipeConnPair creates endpoints that are connected via pipes. +func newPipeConnPair() (a wirenet.Conn, b wirenet.Conn) { + c0, c1 := net.Pipe() + ser := perunio.Serializer() + return wirenet.NewIoConn(c0, ser), wirenet.NewIoConn(c1, ser) +} + +// NewRandomEnvelope returns an envelope around message m with random sender and +// recipient generated using randomness from rng. +func newRandomEnvelope(rng *rand.Rand, m wire.Msg) *wire.Envelope { + return &wire.Envelope{ + Sender: NewRandomAddress(rng), + Recipient: NewRandomAddress(rng), + Msg: m, + } +}