Skip to content

Commit a15081a

Browse files
authored
opt: release outboundBuffer immediately in AsyncWrite(v) if Conn is closed (#673)
Fixes #672
1 parent e9a1101 commit a15081a

File tree

3 files changed

+90
-5
lines changed

3 files changed

+90
-5
lines changed

connection_unix.go

+2
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ func (c *conn) asyncWrite(a any) (err error) {
252252
}()
253253

254254
if !c.opened {
255+
c.outboundBuffer.Release() // release all remaining bytes in the outbound buffer
255256
return net.ErrClosed
256257
}
257258

@@ -273,6 +274,7 @@ func (c *conn) asyncWritev(a any) (err error) {
273274
}()
274275

275276
if !c.opened {
277+
c.outboundBuffer.Release() // release all remaining bytes in the outbound buffer
276278
return net.ErrClosed
277279
}
278280

connection_windows.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -421,21 +421,21 @@ var workerPool = nonBlockingPool{Pool: goPool.Default()}
421421
// func (c *conn) Gfd() gfd.GFD { return gfd.GFD{} }
422422

423423
func (c *conn) AsyncWrite(buf []byte, cb AsyncCallback) error {
424-
_, err := c.Write(buf)
425-
426-
callback := func() error {
424+
fn := func() error {
425+
_, err := c.Write(buf)
427426
if cb != nil {
428427
_ = cb(c, err)
429428
}
430429
return err
431430
}
432431

432+
var err error
433433
select {
434-
case c.loop.ch <- callback:
434+
case c.loop.ch <- fn:
435435
default:
436436
// If the event-loop channel is full, asynchronize this operation to avoid blocking the eventloop.
437437
err = workerPool.Go(func() {
438-
c.loop.ch <- callback
438+
c.loop.ch <- fn
439439
})
440440
}
441441

gnet_test.go

+83
Original file line numberDiff line numberDiff line change
@@ -1532,6 +1532,89 @@ func TestMultiInstLoggerRace(t *testing.T) {
15321532
assert.ErrorIs(t, g.Wait(), errorx.ErrUnsupportedProtocol)
15331533
}
15341534

1535+
type testDisconnectedAsyncWriteServer struct {
1536+
BuiltinEventEngine
1537+
tester *testing.T
1538+
addr string
1539+
writev, clientStarted bool
1540+
exit atomic.Bool
1541+
}
1542+
1543+
func (t *testDisconnectedAsyncWriteServer) OnTraffic(c Conn) Action {
1544+
_, err := c.Next(0)
1545+
require.NoErrorf(t.tester, err, "c.Next error: %v", err)
1546+
1547+
go func() {
1548+
for range time.Tick(100 * time.Millisecond) {
1549+
if t.exit.Load() {
1550+
break
1551+
}
1552+
1553+
if t.writev {
1554+
err = c.AsyncWritev([][]byte{[]byte("hello"), []byte("hello")}, func(_ Conn, err error) error {
1555+
if err == nil {
1556+
return nil
1557+
}
1558+
1559+
require.ErrorIsf(t.tester, err, net.ErrClosed, "expected error: %v, but got: %v", net.ErrClosed, err)
1560+
t.exit.Store(true)
1561+
return nil
1562+
})
1563+
} else {
1564+
err = c.AsyncWrite([]byte("hello"), func(_ Conn, err error) error {
1565+
if err == nil {
1566+
return nil
1567+
}
1568+
1569+
require.ErrorIsf(t.tester, err, net.ErrClosed, "expected error: %v, but got: %v", net.ErrClosed, err)
1570+
t.exit.Store(true)
1571+
return nil
1572+
})
1573+
}
1574+
1575+
if err != nil {
1576+
return
1577+
}
1578+
}
1579+
}()
1580+
1581+
return None
1582+
}
1583+
1584+
func (t *testDisconnectedAsyncWriteServer) OnTick() (delay time.Duration, action Action) {
1585+
delay = 500 * time.Millisecond
1586+
1587+
if t.exit.Load() {
1588+
action = Shutdown
1589+
return
1590+
}
1591+
1592+
if !t.clientStarted {
1593+
t.clientStarted = true
1594+
go func() {
1595+
c, err := net.Dial("tcp", t.addr)
1596+
require.NoError(t.tester, err)
1597+
_, err = c.Write([]byte("hello"))
1598+
require.NoError(t.tester, err)
1599+
require.NoError(t.tester, c.Close())
1600+
}()
1601+
}
1602+
return
1603+
}
1604+
1605+
func TestDisconnectedAsyncWrite(t *testing.T) {
1606+
t.Run("async-write", func(t *testing.T) {
1607+
events := &testDisconnectedAsyncWriteServer{tester: t, addr: ":10000"}
1608+
err := Run(events, "tcp://:10000", WithTicker(true))
1609+
assert.NoError(t, err)
1610+
})
1611+
t.Run("async-writev", func(t *testing.T) {
1612+
events := &testDisconnectedAsyncWriteServer{tester: t, addr: ":10001", writev: true}
1613+
err := Run(events, "tcp://:10001", WithTicker(true))
1614+
assert.NoError(t, err)
1615+
})
1616+
}
1617+
15351618
var errIncompletePacket = errors.New("incomplete packet")
15361619

15371620
type simServer struct {

0 commit comments

Comments
 (0)