diff --git a/p2p/net/swarm/swarm_stream.go b/p2p/net/swarm/swarm_stream.go index 36cd07e46d..4fee368250 100644 --- a/p2p/net/swarm/swarm_stream.go +++ b/p2p/net/swarm/swarm_stream.go @@ -92,7 +92,9 @@ func (s *Stream) Reset() error { } func (s *Stream) ResetWithError(errCode network.StreamErrorCode) error { - panic("not implemented") + err := s.stream.ResetWithError(errCode) + s.closeAndRemoveStream() + return err } func (s *Stream) closeAndRemoveStream() { diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 7cfab5f3ca..f0f0932cd4 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -35,6 +35,7 @@ import ( "go.uber.org/mock/gomock" ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -798,3 +799,74 @@ func TestConnClosedWhenRemoteCloses(t *testing.T) { }) } } + +// TestStreamErrorCode tests correctness for resetting stream with errors +func TestStreamErrorCode(t *testing.T) { + for _, tc := range transportsToTest { + t.Run(tc.Name, func(t *testing.T) { + if tc.Name != "QUIC" { + t.Skip("only implemented for QUIC") + return + } + server := tc.HostGenerator(t, TransportTestCaseOpts{}) + client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + defer server.Close() + defer client.Close() + + checkError := func(err error, code network.StreamErrorCode, remote bool) { + t.Helper() + if err == nil { + t.Fatal("expected non nil error") + } + se := &network.StreamError{} + if errors.As(err, &se) { + require.Equal(t, se.ErrorCode, code) + require.Equal(t, se.Remote, remote) + return + } + t.Fatal("expected network.StreamError, got:", err) + } + + errCh := make(chan error) + server.SetStreamHandler("/test", func(s network.Stream) { + defer s.Reset() + b := make([]byte, 10) + n, err := s.Read(b) + if !assert.NoError(t, err) { + return + } + _, err = s.Write(b[:n]) + if !assert.NoError(t, err) { + return + } + _, err = s.Read(b) + errCh <- err + }) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL) + s, err := client.NewStream(ctx, server.ID(), "/test") + require.NoError(t, err) + + _, err = s.Write([]byte("hello")) + require.NoError(t, err) + + buf := make([]byte, 10) + n, err := s.Read(buf) + require.NoError(t, err) + require.Equal(t, []byte("hello"), buf[:n]) + + err = s.ResetWithError(42) + require.NoError(t, err) + + _, err = s.Read(buf) + checkError(err, 42, false) + + _, err = s.Write(buf) + checkError(err, 42, false) + + err = <-errCh + checkError(err, 42, true) + }) + } +} diff --git a/p2p/transport/quic/conn_test.go b/p2p/transport/quic/conn_test.go index d3e27a7e16..bf3f7b0751 100644 --- a/p2p/transport/quic/conn_test.go +++ b/p2p/transport/quic/conn_test.go @@ -270,6 +270,9 @@ func TestStreams(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { testStreams(t, tc) }) + t.Run(tc.Name, func(t *testing.T) { + testStreamsErrorCode(t, tc) + }) } } @@ -305,6 +308,45 @@ func testStreams(t *testing.T, tc *connTestCase) { require.Equal(t, data, []byte("foobar")) } +func testStreamsErrorCode(t *testing.T, tc *connTestCase) { + serverID, serverKey := createPeer(t) + _, clientKey := createPeer(t) + + serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil) + require.NoError(t, err) + defer serverTransport.(io.Closer).Close() + ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1") + defer ln.Close() + + clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil) + require.NoError(t, err) + defer clientTransport.(io.Closer).Close() + conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) + require.NoError(t, err) + defer conn.Close() + serverConn, err := ln.Accept() + require.NoError(t, err) + defer serverConn.Close() + + str, err := conn.OpenStream(context.Background()) + require.NoError(t, err) + err = str.ResetWithError(42) + require.NoError(t, err) + + sstr, err := serverConn.AcceptStream() + require.NoError(t, err) + _, err = io.ReadAll(sstr) + require.Error(t, err) + se := &network.StreamError{} + if errors.As(err, &se) { + require.Equal(t, se.ErrorCode, network.StreamErrorCode(42)) + require.True(t, se.Remote) + } else { + t.Fatalf("expected error to be of network.StreamError type, got %T, %v", err, err) + } + +} + func TestHandshakeFailPeerIDMismatch(t *testing.T) { for _, tc := range connTestCases { t.Run(tc.Name, func(t *testing.T) { diff --git a/p2p/transport/quic/stream.go b/p2p/transport/quic/stream.go index 7e8568a813..0bdc32f1d0 100644 --- a/p2p/transport/quic/stream.go +++ b/p2p/transport/quic/stream.go @@ -20,11 +20,14 @@ type stream struct { var _ network.MuxedStream = &stream{} func parseStreamError(err error) error { + if err == nil { + return err + } se := &quic.StreamError{} - if err != nil && errors.As(err, &se) { + if errors.As(err, &se) { code := se.ErrorCode if code > math.MaxUint32 { - code = 0 + code = reset } err = &network.StreamError{ ErrorCode: network.StreamErrorCode(code),