Skip to content

Commit

Permalink
client: fix direct dial 4/6 dual stack error
Browse files Browse the repository at this point in the history
  • Loading branch information
rkonfj committed Apr 21, 2024
1 parent 66bbb5b commit 0e2b579
Showing 1 changed file with 46 additions and 24 deletions.
70 changes: 46 additions & 24 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,58 +52,79 @@ func NewTohClient(options Options) (*TohClient, error) {
Exchange: c.dnsExchange,
}
c.directNetDial = func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
dial6 := func() net.Conn {
for _, addr := range c.serverIPv6s {
conn, err = (&net.Dialer{}).DialContext(ctx, network, net.JoinHostPort(addr.String(), c.serverPort))
if err == nil {
return conn
}
}
return nil
}
dial4 := func() net.Conn {
for _, addr := range c.serverIPv4s {
conn, err = (&net.Dialer{}).DialContext(ctx, network, net.JoinHostPort(addr.String(), c.serverPort))
if err == nil {
return conn
}
}
return nil
}
if len(c.serverIPv6s) == 0 && len(c.serverIPv4s) == 0 {
var host string
host, c.serverPort, err = net.SplitHostPort(addr)
if err != nil {
return
}

ipv4Ok := make(chan struct{})
ipv6Ok := make(chan struct{})
connChan := make(chan net.Conn)
go func() {
defer func() { recover() }()
c.serverIPv6s, err = D.LookupIP6(host)
if err != nil {
logrus.Debugf("lookup6 for %s: %s", host, err)
time.AfterFunc(5*time.Second, func() { close(ipv6Ok) })
return
}
if len(c.serverIPv6s) > 0 {
ipv6Ok <- struct{}{}
if conn := dial6(); conn != nil {
connChan <- conn
}
}
}()
go func() {
defer func() { recover() }()
c.serverIPv4s, err = D.LookupIP4(host)
if err != nil {
logrus.Debugf("lookup4 for %s: %s", host, err)
time.AfterFunc(5*time.Second, func() { close(ipv4Ok) })
return
}
if len(c.serverIPv4s) > 0 {
ipv4Ok <- struct{}{}
if conn := dial4(); conn != nil {
connChan <- conn
}
}
}()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
defer close(connChan)
select {
case <-ipv4Ok:
case <-ipv6Ok:
}
}
for _, addr := range c.serverIPv6s { // ipv6 first
conn, err = (&net.Dialer{}).DialContext(ctx, network, net.JoinHostPort(addr.String(), c.serverPort))
if err == nil {
return
case conn := <-connChan:
return conn, nil
case <-ctx.Done():
return nil, spec.ErrDNSRecordNotFound
}
}
for _, addr := range c.serverIPv4s { // fallback to ipv4
conn, err = (&net.Dialer{}).DialContext(ctx, network, net.JoinHostPort(addr.String(), c.serverPort))
if err == nil {
return
}

if conn := dial6(); conn != nil { // ipv6 first
return conn, nil
}
if err == nil {
err = spec.ErrDNSRecordNotFound

if conn := dial4(); conn != nil { // fallback to ipv4
return conn, nil
}
return

return nil, spec.ErrDNSRecordNotFound
}
c.directHttpClient = &http.Client{
Transport: &http.Transport{
Expand All @@ -125,10 +146,11 @@ func (c *TohClient) LookupIP(host string) (ips []net.IP, err error) {
var ip6 []net.IP
go func() {
defer wg.Done()
_ips, e6 := c.LookupIP6(host)
if e6 == nil {
_ips, err := c.LookupIP6(host)
if err == nil {
ip6 = append(ip6, _ips...)
}
e6 = err
}()
_ips, e4 := c.LookupIP4(host)
if e4 == nil {
Expand Down

0 comments on commit 0e2b579

Please sign in to comment.