From 5b7016e577fb31c4ab3d50c80806e8ab94a45c44 Mon Sep 17 00:00:00 2001 From: sukun Date: Thu, 12 Sep 2024 14:07:34 +0530 Subject: [PATCH] add support for websocket --- core/network/conn.go | 2 +- core/network/mux.go | 4 +- p2p/muxer/yamux/conn.go | 4 +- p2p/muxer/yamux/stream.go | 12 +- p2p/net/swarm/swarm.go | 8 ++ p2p/net/swarm/swarm_conn.go | 4 +- p2p/net/upgrader/conn.go | 7 ++ p2p/test/transport/transport_test.go | 162 ++++++++++++++++++++++++++- p2p/transport/quic/conn.go | 4 +- p2p/transport/quic/stream.go | 9 ++ p2p/transport/websocket/conn.go | 8 ++ 11 files changed, 207 insertions(+), 17 deletions(-) diff --git a/core/network/conn.go b/core/network/conn.go index 218c41bf3d..06441558b7 100644 --- a/core/network/conn.go +++ b/core/network/conn.go @@ -16,7 +16,7 @@ type ConnErrorCode uint32 type ConnError struct { Remote bool - ErrorCode uint32 + ErrorCode ConnErrorCode } func (c *ConnError) Error() string { diff --git a/core/network/mux.go b/core/network/mux.go index 0789bfa183..4f584bd591 100644 --- a/core/network/mux.go +++ b/core/network/mux.go @@ -109,10 +109,10 @@ type MuxedConn interface { AcceptStream() (MuxedStream, error) } -type ConnWithErrorer interface { +type CloseWithErrorer interface { // CloseWithError closes the connection with errCode. The errCode is sent // to the peer. - ConnWithError(errCode ConnErrorCode) error + CloseWithError(errCode ConnErrorCode) error } // Multiplexer wraps a net.Conn with a stream multiplexing diff --git a/p2p/muxer/yamux/conn.go b/p2p/muxer/yamux/conn.go index aa7d43d770..4531771842 100644 --- a/p2p/muxer/yamux/conn.go +++ b/p2p/muxer/yamux/conn.go @@ -36,7 +36,7 @@ func (c *conn) IsClosed() bool { func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { s, err := c.yamux().OpenStream(ctx) if err != nil { - return nil, err + return nil, parseResetError(err) } return (*stream)(s), nil @@ -45,7 +45,7 @@ func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { // AcceptStream accepts a stream opened by the other side. func (c *conn) AcceptStream() (network.MuxedStream, error) { s, err := c.yamux().AcceptStream() - return (*stream)(s), err + return (*stream)(s), parseResetError(err) } func (c *conn) yamux() *yamux.Session { diff --git a/p2p/muxer/yamux/stream.go b/p2p/muxer/yamux/stream.go index f9069e8eec..b588c7c2b8 100644 --- a/p2p/muxer/yamux/stream.go +++ b/p2p/muxer/yamux/stream.go @@ -18,11 +18,13 @@ func parseResetError(err error) error { if err == nil { return err } - if errors.Is(err, yamux.ErrStreamReset) { - se := &yamux.StreamError{} - if errors.As(err, &se) { - return &network.StreamError{Remote: se.Remote, ErrorCode: network.StreamErrorCode(se.ErrorCode)} - } + se := &yamux.StreamError{} + if errors.As(err, &se) { + return &network.StreamError{Remote: se.Remote, ErrorCode: network.StreamErrorCode(se.ErrorCode)} + } + ce := &yamux.GoAwayError{} + if errors.As(err, &ce) { + return &network.ConnError{Remote: ce.Remote, ErrorCode: network.ConnErrorCode(ce.ErrorCode)} } return err } diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index 0127555552..89bac7766d 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -838,6 +838,14 @@ func (c connWithMetrics) Close() error { return c.CapableConn.Close() } +func (c connWithMetrics) CloseWithError(errCode network.ConnErrorCode) error { + c.metricsTracer.ClosedConnection(c.dir, time.Since(c.opened), c.ConnState(), c.LocalMultiaddr()) + if ce, ok := c.CapableConn.(network.CloseWithErrorer); ok { + return ce.CloseWithError(errCode) + } + return c.CapableConn.Close() +} + func (c connWithMetrics) Stat() network.ConnStats { if cs, ok := c.CapableConn.(network.ConnStat); ok { return cs.Stat() diff --git a/p2p/net/swarm/swarm_conn.go b/p2p/net/swarm/swarm_conn.go index 44827afd7b..b7cc46fb71 100644 --- a/p2p/net/swarm/swarm_conn.go +++ b/p2p/net/swarm/swarm_conn.go @@ -81,9 +81,7 @@ func (c *Conn) doClose(errCode network.ConnErrorCode) { c.streams.Unlock() if errCode != 0 { - if ce, ok := c.conn.(interface { - CloseWithError(network.ConnErrorCode) error - }); ok { + if ce, ok := c.conn.(network.CloseWithErrorer); ok { c.err = ce.CloseWithError(errCode) } else { c.err = c.conn.Close() diff --git a/p2p/net/upgrader/conn.go b/p2p/net/upgrader/conn.go index 1c23a01aed..18e1e6a931 100644 --- a/p2p/net/upgrader/conn.go +++ b/p2p/net/upgrader/conn.go @@ -63,3 +63,10 @@ func (t *transportConn) ConnState() network.ConnectionState { UsedEarlyMuxerNegotiation: t.usedEarlyMuxerNegotiation, } } + +func (t *transportConn) CloseWithError(errCode network.ConnErrorCode) error { + if ce, ok := t.MuxedConn.(network.CloseWithErrorer); ok { + return ce.CloseWithError(errCode) + } + return t.Close() +} diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 69d8d7f233..ac9ec89955 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -804,8 +804,8 @@ func TestConnClosedWhenRemoteCloses(t *testing.T) { func TestStreamErrorCode(t *testing.T) { for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { - if tc.Name != "QUIC" && tc.Name != "TCP / TLS / Yamux" && tc.Name != "WebRTC" { - t.Skipf("skipping: %s, only implemented for QUIC", tc.Name) + if tc.Name == "WebTransport" { + t.Skipf("skipping: %s, not implemented", tc.Name) return } server := tc.HostGenerator(t, TransportTestCaseOpts{}) @@ -841,6 +841,9 @@ func TestStreamErrorCode(t *testing.T) { } _, err = s.Read(b) errCh <- err + + _, err = s.Write(b) + errCh <- err }) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -865,8 +868,163 @@ func TestStreamErrorCode(t *testing.T) { _, err = s.Write(buf) checkError(err, 42, false) + err = <-errCh // read error + checkError(err, 42, true) + + err = <-errCh // write error + checkError(err, 42, true) + }) + } +} + +// TestStreamErrorCodeConnClosed tests correctness for resetting stream with errors +func TestStreamErrorCodeConnClosed(t *testing.T) { + for _, tc := range transportsToTest { + t.Run(tc.Name, func(t *testing.T) { + if tc.Name == "WebTransport" || tc.Name == "WebRTC" { + t.Skipf("skipping: %s, not implemented", tc.Name) + return + } + server := tc.HostGenerator(t, TransportTestCaseOpts{}) + client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + defer server.Close() + defer client.Close() + + checkError := func(err error, code network.ConnErrorCode, remote bool) { + t.Helper() + if err == nil { + t.Fatal("expected non nil error") + } + ce := &network.ConnError{} + if errors.As(err, &ce) { + require.Equal(t, code, ce.ErrorCode) + require.Equal(t, remote, ce.Remote) + return + } + t.Fatal("expected network.ConnError, got:", err) + } + + errCh := make(chan error) + server.SetStreamHandler("/test", func(s network.Stream) { + defer s.Reset() + b := make([]byte, 10) + n, err := s.Read(b) + if !assert.NoError(t, err) { + return + } + _, err = s.Write(b[:n]) + if !assert.NoError(t, err) { + return + } + _, err = s.Read(b) + errCh <- err + + _, err = s.Write(b) + errCh <- err + }) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL) + s, err := client.NewStream(ctx, server.ID(), "/test") + require.NoError(t, err) + + _, err = s.Write([]byte("hello")) + require.NoError(t, err) + + buf := make([]byte, 10) + n, err := s.Read(buf) + require.NoError(t, err) + require.Equal(t, []byte("hello"), buf[:n]) + + err = s.Conn().CloseWithError(42) + require.NoError(t, err) + + _, err = s.Read(buf) + checkError(err, 42, false) + + _, err = s.Write(buf) + checkError(err, 42, false) + + err = <-errCh + checkError(err, 42, true) + + err = <-errCh + checkError(err, 42, true) + }) + } +} + +// TestConnectionErrorCode tests correctness for resetting stream with errors +func TestConnectionErrorCode(t *testing.T) { + for _, tc := range transportsToTest { + t.Run(tc.Name, func(t *testing.T) { + if tc.Name == "WebTransport" || tc.Name == "WebRTC" { + t.Skipf("skipping: %s, not implemented", tc.Name) + return + } + server := tc.HostGenerator(t, TransportTestCaseOpts{}) + client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + defer server.Close() + defer client.Close() + + checkError := func(err error, code network.ConnErrorCode, remote bool) { + t.Helper() + if err == nil { + t.Fatal("expected non nil error") + } + ce := &network.ConnError{} + if errors.As(err, &ce) { + require.Equal(t, code, ce.ErrorCode) + require.Equal(t, remote, ce.Remote) + return + } + t.Fatal("expected network.ConnError, got:", err) + } + + errCh := make(chan error) + server.SetStreamHandler("/test", func(s network.Stream) { + defer s.Reset() + b := make([]byte, 10) + n, err := s.Read(b) + if !assert.NoError(t, err) { + return + } + _, err = s.Write(b[:n]) + if !assert.NoError(t, err) { + return + } + + _, err = s.Read(b) + if !assert.Error(t, err) { + return + } + _, err = s.Conn().NewStream(context.Background()) + errCh <- err + }) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL) + s, err := client.NewStream(ctx, server.ID(), "/test") + require.NoError(t, err) + + _, err = s.Write([]byte("hello")) + require.NoError(t, err) + + buf := make([]byte, 10) + n, err := s.Read(buf) + require.NoError(t, err) + require.Equal(t, []byte("hello"), buf[:n]) + + err = s.Conn().CloseWithError(42) + require.NoError(t, err) + + str, err := s.Conn().NewStream(context.Background()) + require.Nil(t, str) + checkError(err, 42, false) + err = <-errCh checkError(err, 42, true) + }) } } diff --git a/p2p/transport/quic/conn.go b/p2p/transport/quic/conn.go index ce8e0eca94..8b381d8eda 100644 --- a/p2p/transport/quic/conn.go +++ b/p2p/transport/quic/conn.go @@ -61,7 +61,7 @@ func (c *conn) allowWindowIncrease(size uint64) bool { func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { qstr, err := c.quicConn.OpenStreamSync(ctx) if err != nil { - return nil, err + return nil, parseStreamError(err) } return &stream{Stream: qstr}, nil } @@ -70,7 +70,7 @@ func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { func (c *conn) AcceptStream() (network.MuxedStream, error) { qstr, err := c.quicConn.AcceptStream(context.Background()) if err != nil { - return nil, err + return nil, parseStreamError(err) } return &stream{Stream: qstr}, nil } diff --git a/p2p/transport/quic/stream.go b/p2p/transport/quic/stream.go index 0bdc32f1d0..57d6577f3f 100644 --- a/p2p/transport/quic/stream.go +++ b/p2p/transport/quic/stream.go @@ -27,6 +27,7 @@ func parseStreamError(err error) error { if errors.As(err, &se) { code := se.ErrorCode if code > math.MaxUint32 { + // TODO(sukunrt): do we need this? code = reset } err = &network.StreamError{ @@ -34,6 +35,14 @@ func parseStreamError(err error) error { Remote: se.Remote, } } + ae := &quic.ApplicationError{} + if errors.As(err, &ae) { + code := ae.ErrorCode + err = &network.ConnError{ + ErrorCode: network.ConnErrorCode(code), + Remote: ae.Remote, + } + } return err } diff --git a/p2p/transport/websocket/conn.go b/p2p/transport/websocket/conn.go index 30b70055d0..cd30442f60 100644 --- a/p2p/transport/websocket/conn.go +++ b/p2p/transport/websocket/conn.go @@ -162,3 +162,11 @@ func (c *capableConn) ConnState() network.ConnectionState { cs.Transport = "websocket" return cs } + +// CloseWithError implements network.CloseWithErrorer +func (c *capableConn) CloseWithError(errCode network.ConnErrorCode) error { + if ce, ok := c.CapableConn.(network.CloseWithErrorer); ok { + return ce.CloseWithError(errCode) + } + return c.Close() +}