Skip to content

Commit

Permalink
Daily commit
Browse files Browse the repository at this point in the history
  • Loading branch information
pyropy committed Nov 21, 2024
1 parent 4164b48 commit b96f386
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 122 deletions.
34 changes: 2 additions & 32 deletions p2p/test/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ import (
"github.com/libp2p/go-libp2p/p2p/net/swarm"
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
"github.com/libp2p/go-libp2p/p2p/security/noise"
tls "github.com/libp2p/go-libp2p/p2p/security/tls"
libp2pmemory "github.com/libp2p/go-libp2p/p2p/transport/memory"
libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
"go.uber.org/mock/gomock"

ma "github.com/multiformats/go-multiaddr"
Expand Down Expand Up @@ -98,38 +100,6 @@ var transportsToTest = []TransportTestCase{
return h
},
},
{
Name: "TCP-Shared / TLS / Yamux",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener())
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
if opts.NoListen {
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs)
} else {
libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0"))
}
h, err := libp2p.New(libp2pOpts...)
require.NoError(t, err)
return h
},
},
{
Name: "WebSocket-Shared",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener())
if opts.NoListen {
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs)
} else {
libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/ws"))
}
h, err := libp2p.New(libp2pOpts...)
require.NoError(t, err)
return h
},
},
{
Name: "WebSocket",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
Expand Down
10 changes: 5 additions & 5 deletions p2p/transport/memory/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ func newConnection(

func (c *conn) Close() error {
c.closed.Store(true)
for id, s := range c.streams {
c.removeStream(id)
for _, s := range c.streams {
//c.removeStream(id)
s.Close()
}

Expand All @@ -76,10 +76,10 @@ func (c *conn) IsClosed() bool {
}

func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
inStream, outStream := newStreamPair()
sl, sr := newStreamPair()

c.streamC <- inStream
return outStream, nil
c.streamC <- sr
return sl, nil
}

func (c *conn) AcceptStream() (network.MuxedStream, error) {
Expand Down
162 changes: 86 additions & 76 deletions p2p/transport/memory/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package memory
import (
"errors"
"io"
"log"
"net"
"sync/atomic"
"time"

"github.com/libp2p/go-libp2p/core/network"
Expand All @@ -14,118 +14,96 @@ import (
type stream struct {
id int64

write chan byte
read chan byte
read *io.PipeReader
write *io.PipeWriter
writeC chan []byte

reset chan struct{}
closeRead chan struct{}
closeWrite chan struct{}
closed atomic.Bool
reset chan struct{}
close chan struct{}
closed chan struct{}

writeErr error
}

var ErrClosed = errors.New("stream closed")

func newStreamPair() (*stream, *stream) {
ra, rb := make(chan byte, 4096), make(chan byte, 4096)
ra, wb := io.Pipe()
rb, wa := io.Pipe()

sa := newStream(rb, wa, network.DirOutbound)
sb := newStream(ra, wb, network.DirInbound)

in := newStream(rb, ra, network.DirInbound)
out := newStream(ra, rb, network.DirOutbound)
return in, out
return sa, sb
}

func newStream(r, w chan byte, _ network.Direction) *stream {
func newStream(r *io.PipeReader, w *io.PipeWriter, _ network.Direction) *stream {
s := &stream{
id: streamCounter.Add(1),
read: r,
write: w,
reset: make(chan struct{}, 1),
closeRead: make(chan struct{}, 1),
closeWrite: make(chan struct{}, 1),
id: streamCounter.Add(1),
read: r,
write: w,
writeC: make(chan []byte),
reset: make(chan struct{}, 1),
close: make(chan struct{}, 1),
closed: make(chan struct{}),
}
log.Println("newStream", "id", s.id)

go s.writeLoop()
return s
}

func (s *stream) Write(p []byte) (n int, err error) {
if s.closed.Load() {
return 0, ErrClosed
}
func (s *stream) Write(p []byte) (int, error) {
cpy := make([]byte, len(p))
copy(cpy, p)

select {
case <-s.reset:
return 0, network.ErrReset
case <-s.closeWrite:
return 0, ErrClosed
default:
case <-s.closed:
return 0, s.writeErr
case s.writeC <- cpy:
}

for n < len(p) {
select {
case <-s.closeWrite:
return n, ErrClosed
case <-s.reset:
return n, network.ErrReset
case s.write <- p[n]:
n++
default:
err = io.ErrClosedPipe
return
}
}

return
return len(p), nil
}

func (s *stream) Read(p []byte) (n int, err error) {
if s.closed.Load() {
return 0, ErrClosed
}
return s.read.Read(p)
}

func (s *stream) CloseWrite() error {
select {
case <-s.reset:
return 0, network.ErrReset
case <-s.closeRead:
return 0, ErrClosed
case s.close <- struct{}{}:
default:
}

for n < len(p) {
select {
case <-s.closeRead:
return n, ErrClosed
case <-s.reset:
return n, network.ErrReset
case b, ok := <-s.read:
if !ok {
return n, ErrClosed
}
p[n] = b
n++
default:
return n, io.EOF
}
log.Println("waiting close", "id", s.id)
<-s.closed
log.Println("closed write", "id", s.id)
if !errors.Is(s.writeErr, ErrClosed) {
return s.writeErr
}

return n, nil
}

func (s *stream) CloseWrite() error {
s.closeWrite <- struct{}{}
return nil

}

func (s *stream) CloseRead() error {
s.closeRead <- struct{}{}
return nil
return s.read.CloseWithError(ErrClosed)
}

func (s *stream) Close() error {
s.closed.Store(true)
return nil
_ = s.CloseRead()
return s.CloseWrite()
}

func (s *stream) Reset() error {
s.reset <- struct{}{}
s.write.CloseWithError(network.ErrReset)
s.read.CloseWithError(network.ErrReset)

select {
case s.reset <- struct{}{}:
default:
}
<-s.closed

return nil
}

Expand All @@ -140,3 +118,35 @@ func (s *stream) SetReadDeadline(t time.Time) error {
func (s *stream) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}

func (s *stream) writeLoop() {
defer close(s.closed)
defer log.Println("closing write", "id", s.id)

for {
// Reset takes precedent.
select {
case <-s.reset:
s.writeErr = network.ErrReset
return
default:
}

select {
case <-s.reset:
s.writeErr = network.ErrReset
return
case <-s.close:
s.writeErr = s.write.Close()
if s.writeErr == nil {
s.writeErr = ErrClosed
}
return
case p := <-s.writeC:
if _, err := s.write.Write(p); err != nil {
s.writeErr = err
return
}
}
}
}
4 changes: 2 additions & 2 deletions p2p/transport/memory/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
)

func TestStreamSimpleReadWriteClose(t *testing.T) {
// t.Parallel()
t.Parallel()
clientStr, serverStr := newStreamPair()

// send a foobar from the client
Expand Down Expand Up @@ -49,7 +49,7 @@ func TestStreamSimpleReadWriteClose(t *testing.T) {
}

func TestStreamPartialReads(t *testing.T) {
// t.Parallel()
t.Parallel()
clientStr, serverStr := newStreamPair()

_, err := serverStr.Write([]byte("foobar"))
Expand Down
16 changes: 9 additions & 7 deletions p2p/transport/memory/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,14 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, rpid
return nil, errors.New("failed to get remote public key")
}

inStream, outStream := newStreamPair()
inConn := newConnection(t, outStream, t.localPeerID, nil, remotePubKey, rpid, raddr)
outConn := newConnection(nil, inStream, rpid, raddr, t.localPubKey, t.localPeerID, nil)
l.connCh <- outConn
sl, sr := newStreamPair()

return inConn, nil
lconn := newConnection(t, sl, t.localPeerID, nil, remotePubKey, rpid, raddr)
rconn := newConnection(nil, sr, rpid, raddr, t.localPubKey, t.localPeerID, nil)

l.connCh <- rconn

return lconn, nil
}

func (t *transport) CanDial(addr ma.Multiaddr) bool {
Expand Down Expand Up @@ -186,11 +188,11 @@ func (t *transport) String() string {

func (t *transport) Close() error {
// TODO: Go trough all listeners and close them
memhub.close()
//memhub.close()

for _, c := range t.connections {
c.Close()
delete(t.connections, c.id)
//delete(t.connections, c.id)
}

return nil
Expand Down
2 changes: 2 additions & 0 deletions p2p/transport/memory/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ func getTransport(t *testing.T) tpt.Transport {
}

func TestMemoryProtocol(t *testing.T) {
t.Parallel()
tr := getTransport(t)
defer tr.(io.Closer).Close()

Expand All @@ -40,6 +41,7 @@ func TestMemoryProtocol(t *testing.T) {
}

func TestCanDial(t *testing.T) {
t.Parallel()
tr := getTransport(t)
defer tr.(io.Closer).Close()

Expand Down

0 comments on commit b96f386

Please sign in to comment.