Skip to content

Commit

Permalink
feat: add DialContext and EnrollContext for Client (#543)
Browse files Browse the repository at this point in the history
Fixes #541
  • Loading branch information
linfeip authored Mar 19, 2024
1 parent a3a2b7a commit 54f81b6
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 50 deletions.
71 changes: 23 additions & 48 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"io"
"math/rand"
"net"
"sync"
"sync/atomic"
"testing"
"time"
Expand All @@ -21,12 +20,17 @@ import (
goPool "github.com/panjf2000/gnet/v2/pkg/pool/goroutine"
)

type connHandler struct {
network string
rspCh chan []byte
data []byte
}

type clientEvents struct {
*BuiltinEventEngine
tester *testing.T
svr *testClientServer
packetLen int
rspChMap sync.Map
}

func (ev *clientEvents) OnBoot(e Engine) Action {
Expand All @@ -37,13 +41,6 @@ func (ev *clientEvents) OnBoot(e Engine) Action {
return None
}

func (ev *clientEvents) OnOpen(c Conn) ([]byte, Action) {
c.SetContext([]byte{})
rspCh := make(chan []byte, 1)
ev.rspChMap.Store(c.LocalAddr().String(), rspCh)
return nil, None
}

func (ev *clientEvents) OnClose(Conn, error) Action {
if ev.svr != nil {
if atomic.AddInt32(&ev.svr.clientActive, -1) == 0 {
Expand All @@ -54,24 +51,18 @@ func (ev *clientEvents) OnClose(Conn, error) Action {
}

func (ev *clientEvents) OnTraffic(c Conn) (action Action) {
ctx := c.Context()
var p []byte
if ctx != nil {
p = ctx.([]byte)
} else { // UDP
handler := c.Context().(*connHandler)
if handler.network == "udp" {
ev.packetLen = 1024
}
buf, err := c.Next(-1)
assert.NoError(ev.tester, err)
p = append(p, buf...)
if len(p) < ev.packetLen {
c.SetContext(p)
handler.data = append(handler.data, buf...)
if len(handler.data) < ev.packetLen {
return
}
v, _ := ev.rspChMap.Load(c.LocalAddr().String())
rspCh := v.(chan []byte)
rspCh <- p
c.SetContext([]byte{})
handler.rspCh <- handler.data
handler.data = nil
return
}

Expand Down Expand Up @@ -200,7 +191,6 @@ func TestServeWithGnetClient(t *testing.T) {
type testClientServer struct {
*BuiltinEventEngine
client *Client
clientEV *clientEvents
tester *testing.T
eng Engine
network string
Expand Down Expand Up @@ -277,7 +267,7 @@ func (s *testClientServer) OnTick() (delay time.Duration, action Action) {
if i%2 == 0 {
netConn = true
}
go startGnetClient(s.tester, s.client, s.clientEV, s.network, s.addr, s.multicore, s.async, netConn)
go startGnetClient(s.tester, s.client, s.network, s.addr, s.multicore, s.async, netConn)
}
}
if s.network == "udp" && atomic.LoadInt32(&s.clientActive) == 0 {
Expand All @@ -298,9 +288,9 @@ func testServeWithGnetClient(t *testing.T, network, addr string, reuseport, reus
workerPool: goPool.Default(),
}
var err error
ts.clientEV = &clientEvents{tester: t, packetLen: streamLen, svr: ts}
clientEV := &clientEvents{tester: t, packetLen: streamLen, svr: ts}
ts.client, err = NewClient(
ts.clientEV,
clientEV,
WithLogLevel(logging.DebugLevel),
WithLockOSThread(true),
WithTicker(true),
Expand All @@ -324,44 +314,29 @@ func testServeWithGnetClient(t *testing.T, network, addr string, reuseport, reus
assert.NoError(t, err)
}

func startGnetClient(t *testing.T, cli *Client, ev *clientEvents, network, addr string, multicore, async, netDial bool) {
func startGnetClient(t *testing.T, cli *Client, network, addr string, multicore, async, netDial bool) {
rand.Seed(time.Now().UnixNano())
var (
c Conn
err error
)
handler := &connHandler{
network: network,
rspCh: make(chan []byte, 1),
}
if netDial {
var netConn net.Conn
netConn, err = NetDial(network, addr)
require.NoError(t, err)
c, err = cli.Enroll(netConn)
c, err = cli.EnrollContext(netConn, handler)
} else {
c, err = cli.Dial(network, addr)
c, err = cli.DialContext(network, addr, handler)
}
require.NoError(t, err)
defer c.Close()
err = c.Wake(nil)
require.NoError(t, err)
var rspCh chan []byte
if network == "udp" {
rspCh = make(chan []byte, 1)
ev.rspChMap.Store(c.LocalAddr().String(), rspCh)
} else {
var (
v interface{}
ok bool
)
start := time.Now()
for time.Since(start) < time.Second {
v, ok = ev.rspChMap.Load(c.LocalAddr().String())
if ok {
break
}
time.Sleep(10 * time.Millisecond)
}
require.True(t, ok)
rspCh = v.(chan []byte)
}
rspCh := handler.rspCh
duration := time.Duration((rand.Float64()*2+1)*float64(time.Second)) / 2
t.Logf("test duration: %dms", duration/time.Millisecond)
start := time.Now()
Expand Down
13 changes: 12 additions & 1 deletion client_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,15 +140,25 @@ func (cli *Client) Stop() (err error) {

// Dial is like net.Dial().
func (cli *Client) Dial(network, address string) (Conn, error) {
return cli.DialContext(network, address, nil)
}

// DialContext is like Dial but also accepts an empty interface ctx that can be obtained later via Conn.Context.
func (cli *Client) DialContext(network, address string, ctx interface{}) (Conn, error) {
c, err := net.Dial(network, address)
if err != nil {
return nil, err
}
return cli.Enroll(c)
return cli.EnrollContext(c, ctx)
}

// Enroll converts a net.Conn to gnet.Conn and then adds it into Client.
func (cli *Client) Enroll(c net.Conn) (Conn, error) {
return cli.EnrollContext(c, nil)
}

// EnrollContext is like Enroll but also accepts an empty interface ctx that can be obtained later via Conn.Context.
func (cli *Client) EnrollContext(c net.Conn, ctx interface{}) (Conn, error) {
defer c.Close()

sc, ok := c.(syscall.Conn)
Expand Down Expand Up @@ -217,6 +227,7 @@ func (cli *Client) Enroll(c net.Conn) (Conn, error) {
default:
return nil, errorx.ErrUnsupportedProtocol
}
gc.SetContext(ctx)
err = cli.el.poller.UrgentTrigger(cli.el.register, gc)
if err != nil {
gc.Close()
Expand Down
14 changes: 13 additions & 1 deletion client_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ func unixAddr(addr string) string {
}

func (cli *Client) Dial(network, addr string) (Conn, error) {
return cli.DialContext(network, addr, nil)
}

func (cli *Client) DialContext(network, addr string, ctx interface{}) (Conn, error) {
var (
c net.Conn
err error
Expand All @@ -135,10 +139,14 @@ func (cli *Client) Dial(network, addr string) (Conn, error) {
return nil, err
}
}
return cli.Enroll(c)
return cli.EnrollContext(c, ctx)
}

func (cli *Client) Enroll(nc net.Conn) (gc Conn, err error) {
return cli.EnrollContext(nc, nil)
}

func (cli *Client) EnrollContext(nc net.Conn, ctx interface{}) (gc Conn, err error) {
switch v := nc.(type) {
case *net.TCPConn:
if cli.opts.TCPNoDelay == TCPNoDelay {
Expand All @@ -156,6 +164,7 @@ func (cli *Client) Enroll(nc net.Conn) (gc Conn, err error) {
}

c := newTCPConn(nc, cli.el)
c.SetContext(ctx)
cli.el.ch <- c
go func(c *conn, tc net.Conn, el *eventloop) {
var buffer [0x10000]byte
Expand All @@ -171,6 +180,7 @@ func (cli *Client) Enroll(nc net.Conn) (gc Conn, err error) {
gc = c
case *net.UnixConn:
c := newTCPConn(nc, cli.el)
c.SetContext(ctx)
cli.el.ch <- c
go func(c *conn, uc net.Conn, el *eventloop) {
var buffer [0x10000]byte
Expand All @@ -192,6 +202,7 @@ func (cli *Client) Enroll(nc net.Conn) (gc Conn, err error) {
gc = c
case *net.UDPConn:
c := newUDPConn(cli.el, nc.LocalAddr(), nc.RemoteAddr())
c.SetContext(ctx)
c.rawConn = nc
go func(uc net.Conn, el *eventloop) {
var buffer [0x10000]byte
Expand All @@ -201,6 +212,7 @@ func (cli *Client) Enroll(nc net.Conn) (gc Conn, err error) {
return
}
c := newUDPConn(cli.el, uc.LocalAddr(), uc.RemoteAddr())
c.SetContext(ctx)
c.rawConn = uc
el.ch <- packUDPConn(c, buffer[:n])
}
Expand Down

0 comments on commit 54f81b6

Please sign in to comment.