diff --git a/client/client.go b/client/client.go index 5f0799a..2ace021 100644 --- a/client/client.go +++ b/client/client.go @@ -52,6 +52,24 @@ func NewTohClient(options Options) (*TohClient, error) { Exchange: c.dnsExchange, } c.directNetDial = func(ctx context.Context, network, addr string) (conn net.Conn, err error) { + dial6 := func() net.Conn { + for _, addr := range c.serverIPv6s { + conn, err = (&net.Dialer{}).DialContext(ctx, network, net.JoinHostPort(addr.String(), c.serverPort)) + if err == nil { + return conn + } + } + return nil + } + dial4 := func() net.Conn { + for _, addr := range c.serverIPv4s { + conn, err = (&net.Dialer{}).DialContext(ctx, network, net.JoinHostPort(addr.String(), c.serverPort)) + if err == nil { + return conn + } + } + return nil + } if len(c.serverIPv6s) == 0 && len(c.serverIPv4s) == 0 { var host string host, c.serverPort, err = net.SplitHostPort(addr) @@ -59,51 +77,54 @@ func NewTohClient(options Options) (*TohClient, error) { return } - ipv4Ok := make(chan struct{}) - ipv6Ok := make(chan struct{}) + connChan := make(chan net.Conn) go func() { + defer func() { recover() }() c.serverIPv6s, err = D.LookupIP6(host) if err != nil { logrus.Debugf("lookup6 for %s: %s", host, err) - time.AfterFunc(5*time.Second, func() { close(ipv6Ok) }) return } if len(c.serverIPv6s) > 0 { - ipv6Ok <- struct{}{} + if conn := dial6(); conn != nil { + connChan <- conn + } } }() go func() { + defer func() { recover() }() c.serverIPv4s, err = D.LookupIP4(host) if err != nil { logrus.Debugf("lookup4 for %s: %s", host, err) - time.AfterFunc(5*time.Second, func() { close(ipv4Ok) }) return } if len(c.serverIPv4s) > 0 { - ipv4Ok <- struct{}{} + if conn := dial4(); conn != nil { + connChan <- conn + } } }() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + defer close(connChan) select { - case <-ipv4Ok: - case <-ipv6Ok: - } - } - for _, addr := range c.serverIPv6s { // ipv6 first - conn, err = (&net.Dialer{}).DialContext(ctx, network, net.JoinHostPort(addr.String(), c.serverPort)) - if err == nil { - return + case conn := <-connChan: + return conn, nil + case <-ctx.Done(): + return nil, spec.ErrDNSRecordNotFound } } - for _, addr := range c.serverIPv4s { // fallback to ipv4 - conn, err = (&net.Dialer{}).DialContext(ctx, network, net.JoinHostPort(addr.String(), c.serverPort)) - if err == nil { - return - } + + if conn := dial6(); conn != nil { // ipv6 first + return conn, nil } - if err == nil { - err = spec.ErrDNSRecordNotFound + + if conn := dial4(); conn != nil { // fallback to ipv4 + return conn, nil } - return + + return nil, spec.ErrDNSRecordNotFound } c.directHttpClient = &http.Client{ Transport: &http.Transport{ @@ -125,10 +146,11 @@ func (c *TohClient) LookupIP(host string) (ips []net.IP, err error) { var ip6 []net.IP go func() { defer wg.Done() - _ips, e6 := c.LookupIP6(host) - if e6 == nil { + _ips, err := c.LookupIP6(host) + if err == nil { ip6 = append(ip6, _ips...) } + e6 = err }() _ips, e4 := c.LookupIP4(host) if e4 == nil {