Skip to content

Commit cdf3a6a

Browse files
committed
Add support for -4 and -6 switches.
1 parent db0257c commit cdf3a6a

12 files changed

+348
-29
lines changed

.golangci.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@
125125
text = "(tlsFeatureExtensionOID|ocspMustStapleFeature) is a global variable"
126126
[[issues.exclude-rules]]
127127
path = "challenge/dns01/nameserver.go"
128-
text = "(defaultNameservers|recursiveNameservers|fqdnSoaCache|muFqdnSoaCache) is a global variable"
128+
text = "(defaultNameservers|recursiveNameservers|currentNetworkStack|fqdnSoaCache|muFqdnSoaCache) is a global variable"
129129
[[issues.exclude-rules]]
130130
path = "challenge/dns01/nameserver_.+.go"
131131
text = "dnsTimeout is a global variable"

challenge/dns01/nameserver.go

+47-3
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,24 @@ var defaultNameservers = []string{
2626
// recursiveNameservers are used to pre-check DNS propagation.
2727
var recursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers)
2828

29+
// NetworkStack is used to indicate which IP stack should be used for DNS
30+
// queries. Valid values are DefaultNetworkStack, IPv4Only, and IPv6Only.
31+
type NetworkStack int
32+
33+
const (
34+
// DefaultNetworkStack indicates that both IPv4 and IPv6 should be allowed.
35+
// This setting lets the OS determine which IP stack to use.
36+
DefaultNetworkStack NetworkStack = iota
37+
// IPv4Only forces DNS queries to only happen over the IPv4 stack.
38+
IPv4Only
39+
// IPv6Only forces DNS queries to only happen over the IPv6 stack.
40+
IPv6Only
41+
)
42+
43+
// currentNetworkStack is used to define which IP stack will be used. The default is
44+
// both IPv4 and IPv6. Set to IPv4Only or IPv6Only to select either version.
45+
var currentNetworkStack = DefaultNetworkStack
46+
2947
// soaCacheEntry holds a cached SOA record (only selected fields).
3048
type soaCacheEntry struct {
3149
zone string // zone apex (a domain name)
@@ -67,6 +85,11 @@ func AddRecursiveNameservers(nameservers []string) ChallengeOption {
6785
}
6886
}
6987

88+
// SetNetworkStack defines the IP stack that will be used for DNS queries.
89+
func SetNetworkStack(network NetworkStack) {
90+
currentNetworkStack = network
91+
}
92+
7093
// getNameservers attempts to get systems nameservers before falling back to the defaults.
7194
func getNameservers(path string, defaults []string) []string {
7295
config, err := dns.ClientConfigFromFile(path)
@@ -249,12 +272,33 @@ func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
249272
return m
250273
}
251274

275+
// getNetwork interprets the NetworkStack setting in relation to the desired
276+
// protocol. The proto value should be either "udp" or "tcp".
277+
func getNetwork(proto string) string {
278+
// The dns client passes whatever value is set in [dns.Client.Net] to
279+
// the [net.Dialer] (https://github.com/miekg/dns/blob/fe20d5d/client.go#L119-L141).
280+
// And the [net.Dialer] accepts strings such as "udp4" or "tcp6"
281+
// (https://cs.opensource.google/go/go/+/refs/tags/go1.18.9:src/net/dial.go;l=167-182).
282+
if currentNetworkStack == IPv4Only {
283+
return proto + "4"
284+
}
285+
if currentNetworkStack == IPv6Only {
286+
return proto + "6"
287+
}
288+
return proto
289+
}
290+
252291
func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) {
253-
udp := &dns.Client{Net: "udp", Timeout: dnsTimeout}
292+
network := getNetwork("udp")
293+
udp := &dns.Client{Net: network, Timeout: dnsTimeout}
254294
in, _, err := udp.Exchange(m, ns)
255295

256-
if in != nil && in.Truncated {
257-
tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
296+
network = getNetwork("tcp")
297+
// We can encounter a net.OpError if the nameserver is not listening
298+
// on UDP at all, i.e. net.Dial could not make a connection.
299+
_, isOpErr := err.(*net.OpError)
300+
if (in != nil && in.Truncated) || isOpErr {
301+
tcp := &dns.Client{Net: network, Timeout: dnsTimeout}
258302
// If the TCP request succeeds, the err will reset to nil
259303
in, _, err = tcp.Exchange(m, ns)
260304
}

challenge/dns01/nameserver_test.go

+138
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,151 @@
11
package dns01
22

33
import (
4+
"fmt"
5+
getport "github.com/jsumners/go-getport"
6+
"github.com/miekg/dns"
7+
"net"
48
"sort"
9+
"sync"
510
"testing"
611

712
"github.com/stretchr/testify/assert"
813
"github.com/stretchr/testify/require"
914
)
1015

16+
type testDnsHandler struct{}
17+
type testDnsServer struct {
18+
*dns.Server
19+
getport.PortResult
20+
}
21+
22+
func (handler *testDnsHandler) ServeDNS(writer dns.ResponseWriter, reply *dns.Msg) {
23+
msg := dns.Msg{}
24+
msg.SetReply(reply)
25+
26+
switch reply.Question[0].Qtype {
27+
case dns.TypeA:
28+
msg.Authoritative = true
29+
domain := msg.Question[0].Name
30+
msg.Answer = append(
31+
msg.Answer,
32+
&dns.A{
33+
Hdr: dns.RR_Header{
34+
Name: domain,
35+
Rrtype: dns.TypeA,
36+
Class: dns.ClassINET,
37+
Ttl: 60,
38+
},
39+
A: net.ParseIP("127.0.0.1"),
40+
},
41+
)
42+
}
43+
44+
writer.WriteMsg(&msg)
45+
}
46+
47+
// getTestNameserver constructs a new DNS server on a local address, or set
48+
// of addresses, that responds to an `A` query for `example.com`.
49+
func getTestNameserver(t *testing.T, network string) testDnsServer {
50+
server := &dns.Server{
51+
Handler: new(testDnsHandler),
52+
Net: network,
53+
}
54+
testServer := testDnsServer{
55+
Server: server,
56+
}
57+
58+
var protocol getport.Protocol
59+
var address string
60+
switch network {
61+
case "tcp":
62+
protocol = getport.TCP
63+
address = "0.0.0.0"
64+
case "tcp4":
65+
protocol = getport.TCP4
66+
address = "127.0.0.1"
67+
case "tcp6":
68+
protocol = getport.TCP6
69+
address = "::1"
70+
case "udp":
71+
protocol = getport.UDP
72+
address = "0.0.0.0"
73+
case "udp4":
74+
protocol = getport.UDP4
75+
address = "127.0.0.1"
76+
case "udp6":
77+
protocol = getport.UDP6
78+
address = "::1"
79+
}
80+
portResult, portError := getport.GetPort(protocol, address)
81+
if portError != nil {
82+
t.Error(portError)
83+
return testServer
84+
}
85+
testServer.PortResult = portResult
86+
server.Addr = getport.PortResultToAddress(portResult)
87+
88+
waitLock := sync.Mutex{}
89+
waitLock.Lock()
90+
server.NotifyStartedFunc = waitLock.Unlock
91+
92+
fin := make(chan error, 1)
93+
go func() {
94+
fin <- server.ListenAndServe()
95+
}()
96+
97+
waitLock.Lock()
98+
return testServer
99+
}
100+
101+
func TestSendDNSQuery(t *testing.T) {
102+
t.Run("does udp4 only", func(t *testing.T) {
103+
SetNetworkStack(IPv4Only)
104+
nameserver := getTestNameserver(t, getNetwork("udp"))
105+
defer nameserver.Server.Shutdown()
106+
107+
serverAddress := fmt.Sprintf("127.0.0.1:%d", nameserver.PortResult.Port)
108+
recursiveNameservers = ParseNameservers([]string{serverAddress})
109+
msg := createDNSMsg("example.com.", dns.TypeA, true)
110+
result, queryError := sendDNSQuery(msg, serverAddress)
111+
assert.NoError(t, queryError)
112+
assert.Equal(t, result.Answer[0].(*dns.A).A.String(), "127.0.0.1")
113+
})
114+
115+
t.Run("does udp6 only", func(t *testing.T) {
116+
SetNetworkStack(IPv6Only)
117+
nameserver := getTestNameserver(t, getNetwork("udp"))
118+
defer nameserver.Server.Shutdown()
119+
120+
serverAddress := fmt.Sprintf("[::1]:%d", nameserver.PortResult.Port)
121+
recursiveNameservers = ParseNameservers([]string{serverAddress})
122+
msg := createDNSMsg("example.com.", dns.TypeA, true)
123+
result, queryError := sendDNSQuery(msg, serverAddress)
124+
assert.NoError(t, queryError)
125+
assert.Equal(t, result.Answer[0].(*dns.A).A.String(), "127.0.0.1")
126+
})
127+
128+
t.Run("does tcp4 and tcp6", func(t *testing.T) {
129+
SetNetworkStack(DefaultNetworkStack)
130+
nameserver := getTestNameserver(t, getNetwork("tcp"))
131+
defer nameserver.Server.Shutdown()
132+
133+
serverAddress := fmt.Sprintf("[::1]:%d", nameserver.PortResult.Port)
134+
recursiveNameservers = ParseNameservers([]string{serverAddress})
135+
msg := createDNSMsg("example.com.", dns.TypeA, true)
136+
result, queryError := sendDNSQuery(msg, serverAddress)
137+
assert.NoError(t, queryError)
138+
assert.Equal(t, result.Answer[0].(*dns.A).A.String(), "127.0.0.1")
139+
140+
serverAddress = fmt.Sprintf("127.0.0.1:%d", nameserver.PortResult.Port)
141+
recursiveNameservers = ParseNameservers([]string{serverAddress})
142+
msg = createDNSMsg("example.com.", dns.TypeA, true)
143+
result, queryError = sendDNSQuery(msg, serverAddress)
144+
assert.NoError(t, queryError)
145+
assert.Equal(t, result.Answer[0].(*dns.A).A.String(), "127.0.0.1")
146+
})
147+
}
148+
11149
func TestLookupNameserversOK(t *testing.T) {
12150
testCases := []struct {
13151
fqdn string

challenge/http01/http_challenge_server.go

+13-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,14 @@ import (
1212
"github.com/go-acme/lego/v4/log"
1313
)
1414

15+
type ProviderNetwork string
16+
17+
const (
18+
DefaultNetwork = "tcp"
19+
Tcp4Network = "tcp4"
20+
Tcp6Network = "tcp6"
21+
)
22+
1523
// ProviderServer implements ChallengeProvider for `http-01` challenge.
1624
// It may be instantiated without using the NewProviderServer function if
1725
// you want only to use the default values.
@@ -29,12 +37,15 @@ type ProviderServer struct {
2937
// NewProviderServer creates a new ProviderServer on the selected interface and port.
3038
// Setting iface and / or port to an empty string will make the server fall back to
3139
// the "any" interface and port 80 respectively.
32-
func NewProviderServer(iface, port string) *ProviderServer {
40+
func NewProviderServer(iface, port string, network ProviderNetwork) *ProviderServer {
3341
if port == "" {
3442
port = "80"
3543
}
44+
if network == "" {
45+
network = DefaultNetwork
46+
}
3647

37-
return &ProviderServer{network: "tcp", address: net.JoinHostPort(iface, port), matcher: &hostMatcher{}}
48+
return &ProviderServer{network: string(network), address: net.JoinHostPort(iface, port), matcher: &hostMatcher{}}
3849
}
3950

4051
func NewUnixProviderServer(socketPath string, mode fs.FileMode) *ProviderServer {

challenge/http01/http_challenge_test.go

+16-6
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,27 @@ func TestProviderServer_GetAddress(t *testing.T) {
3636
}{
3737
{
3838
desc: "TCP default address",
39-
server: NewProviderServer("", ""),
39+
server: NewProviderServer("", "", ""),
4040
expected: ":80",
4141
},
4242
{
4343
desc: "TCP with explicit port",
44-
server: NewProviderServer("", "8080"),
44+
server: NewProviderServer("", "8080", ""),
4545
expected: ":8080",
4646
},
4747
{
4848
desc: "TCP with host and port",
49-
server: NewProviderServer("localhost", "8080"),
49+
server: NewProviderServer("localhost", "8080", ""),
50+
expected: "localhost:8080",
51+
},
52+
{
53+
desc: "TCP4 with host and port",
54+
server: NewProviderServer("localhost", "8080", Tcp4Network),
55+
expected: "localhost:8080",
56+
},
57+
{
58+
desc: "TCP6 with host and port",
59+
server: NewProviderServer("localhost", "8080", Tcp6Network),
5060
expected: "localhost:8080",
5161
},
5262
{
@@ -70,7 +80,7 @@ func TestProviderServer_GetAddress(t *testing.T) {
7080
func TestChallenge(t *testing.T) {
7181
_, apiURL := tester.SetupFakeAPI(t)
7282

73-
providerServer := NewProviderServer("", "23457")
83+
providerServer := NewProviderServer("", "23457", "")
7484

7585
validate := func(_ *api.Core, _ string, chlng acme.Challenge) error {
7686
uri := "http://localhost" + providerServer.GetAddress() + ChallengePath(chlng.Token)
@@ -199,7 +209,7 @@ func TestChallengeInvalidPort(t *testing.T) {
199209

200210
validate := func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }
201211

202-
solver := NewChallenge(core, validate, NewProviderServer("", "123456"))
212+
solver := NewChallenge(core, validate, NewProviderServer("", "123456", ""))
203213

204214
authz := acme.Authorization{
205215
Identifier: acme.Identifier{
@@ -374,7 +384,7 @@ func testServeWithProxy(t *testing.T, header, extra *testProxyHeader, expectErro
374384

375385
_, apiURL := tester.SetupFakeAPI(t)
376386

377-
providerServer := NewProviderServer("localhost", "23457")
387+
providerServer := NewProviderServer("localhost", "23457", "")
378388
if header != nil {
379389
providerServer.SetProxyHeader(header.name)
380390
}

challenge/tlsalpn01/tls_alpn_challenge_server.go

+20-3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@ import (
1111
"github.com/go-acme/lego/v4/log"
1212
)
1313

14+
type ProviderNetwork string
15+
16+
const (
17+
DefaultNetwork = "tcp"
18+
Tcp4Network = "tcp4"
19+
Tcp6Network = "tcp6"
20+
)
21+
1422
const (
1523
// ACMETLS1Protocol is the ALPN Protocol ID for the ACME-TLS/1 Protocol.
1624
ACMETLS1Protocol = "acme-tls/1"
@@ -26,14 +34,23 @@ const (
2634
type ProviderServer struct {
2735
iface string
2836
port string
37+
network string
2938
listener net.Listener
3039
}
3140

3241
// NewProviderServer creates a new ProviderServer on the selected interface and port.
3342
// Setting iface and / or port to an empty string will make the server fall back to
3443
// the "any" interface and port 443 respectively.
35-
func NewProviderServer(iface, port string) *ProviderServer {
36-
return &ProviderServer{iface: iface, port: port}
44+
func NewProviderServer(iface, port string, network ProviderNetwork) *ProviderServer {
45+
if port == "" {
46+
port = defaultTLSPort
47+
}
48+
49+
if network == "" {
50+
network = DefaultNetwork
51+
}
52+
53+
return &ProviderServer{iface: iface, port: port, network: string(network)}
3754
}
3855

3956
func (s *ProviderServer) GetAddress() string {
@@ -65,7 +82,7 @@ func (s *ProviderServer) Present(domain, token, keyAuth string) error {
6582
tlsConf.NextProtos = []string{ACMETLS1Protocol}
6683

6784
// Create the listener with the created tls.Config.
68-
s.listener, err = tls.Listen("tcp", s.GetAddress(), tlsConf)
85+
s.listener, err = tls.Listen(s.network, s.GetAddress(), tlsConf)
6986
if err != nil {
7087
return fmt.Errorf("could not start HTTPS server for challenge: %w", err)
7188
}

0 commit comments

Comments
 (0)