Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add retry logic to DNS prober #1267

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,8 @@ type DNSProbe struct {
ValidateAnswer DNSRRValidator `yaml:"validate_answer_rrs,omitempty"`
ValidateAuthority DNSRRValidator `yaml:"validate_authority_rrs,omitempty"`
ValidateAdditional DNSRRValidator `yaml:"validate_additional_rrs,omitempty"`
Retries int `yaml:"retries,omitempty"`
PerRequestTimeout time.Duration `yaml:"per_request_timeout,omitempty"`
}

type DNSRRValidator struct {
Expand Down
194 changes: 116 additions & 78 deletions prober/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package prober

import (
"context"
"errors"
"net"
"regexp"
"time"
Expand Down Expand Up @@ -146,6 +147,10 @@ func ProbeDNS(ctx context.Context, target string, module config.Module, registry
Name: "probe_dns_query_succeeded",
Help: "Displays whether or not the query was executed successfully",
})
probeDNSRetries := prometheus.NewGauge(prometheus.GaugeOpts{
Name: "probe_dns_retries",
Help: "The number of retries the probe took to complete successfully",
})

for _, lv := range []string{"resolve", "connect", "request"} {
probeDNSDurationGaugeVec.WithLabelValues(lv)
Expand All @@ -156,6 +161,7 @@ func ProbeDNS(ctx context.Context, target string, module config.Module, registry
registry.MustRegister(probeDNSAuthorityRRSGauge)
registry.MustRegister(probeDNSAdditionalRRSGauge)
registry.MustRegister(probeDNSQuerySucceeded)
registry.MustRegister(probeDNSRetries)

qc := uint16(dns.ClassINET)
if module.DNS.QueryClass != "" {
Expand Down Expand Up @@ -220,98 +226,130 @@ func ProbeDNS(ctx context.Context, target string, module config.Module, registry
}
}

client := new(dns.Client)
client.Net = dialProtocol
probeDNSRetries.Set(0)

if module.DNS.DNSOverTLS {
tlsConfig, err := pconfig.NewTLSConfig(&module.DNS.TLSConfig)
if err != nil {
level.Error(logger).Log("msg", "Failed to create TLS configuration", "err", err)
return false
}
if tlsConfig.ServerName == "" {
// Use target-hostname as default for TLS-servername.
tlsConfig.ServerName = targetAddr
for retry := 0; retry <= module.DNS.Retries; retry++ {
if retry > 0 {
level.Info(logger).Log("msg", "Retrying request", "retry", retry)
probeDNSRetries.Inc()
}

client.TLSConfig = tlsConfig
}
client := new(dns.Client)
client.Net = dialProtocol

// Use configured SourceIPAddress.
if len(module.DNS.SourceIPAddress) > 0 {
srcIP := net.ParseIP(module.DNS.SourceIPAddress)
if srcIP == nil {
level.Error(logger).Log("msg", "Error parsing source ip address", "srcIP", module.DNS.SourceIPAddress)
return false
if module.DNS.DNSOverTLS {
tlsConfig, err := pconfig.NewTLSConfig(&module.DNS.TLSConfig)
if err != nil {
level.Error(logger).Log("msg", "Failed to create TLS configuration", "err", err)
// This is not retryable.
return false
}
if tlsConfig.ServerName == "" {
// Use target-hostname as default for TLS-servername.
tlsConfig.ServerName = targetAddr
}

client.TLSConfig = tlsConfig
}
level.Info(logger).Log("msg", "Using local address", "srcIP", srcIP)
client.Dialer = &net.Dialer{}
if module.DNS.TransportProtocol == "tcp" {
client.Dialer.LocalAddr = &net.TCPAddr{IP: srcIP}

// Use configured SourceIPAddress.
if len(module.DNS.SourceIPAddress) > 0 {
srcIP := net.ParseIP(module.DNS.SourceIPAddress)
if srcIP == nil {
level.Error(logger).Log("msg", "Error parsing source ip address", "srcIP", module.DNS.SourceIPAddress)
// This is not retryable.
return false
}
level.Info(logger).Log("msg", "Using local address", "srcIP", srcIP)
client.Dialer = &net.Dialer{}
if module.DNS.TransportProtocol == "tcp" {
client.Dialer.LocalAddr = &net.TCPAddr{IP: srcIP}
} else {
client.Dialer.LocalAddr = &net.UDPAddr{IP: srcIP}
}
}

msg := new(dns.Msg)
msg.Id = dns.Id()
msg.RecursionDesired = module.DNS.Recursion
msg.Question = make([]dns.Question, 1)
msg.Question[0] = dns.Question{dns.Fqdn(module.DNS.QueryName), qt, qc}

level.Info(logger).Log("msg", "Making DNS query", "target", targetIP, "dial_protocol", dialProtocol, "query", module.DNS.QueryName, "type", qt, "class", qc)

if module.DNS.PerRequestTimeout == 0 {
timeoutDeadline, _ := ctx.Deadline()
client.Timeout = time.Until(timeoutDeadline)
} else {
client.Dialer.LocalAddr = &net.UDPAddr{IP: srcIP}
client.Timeout = module.DNS.PerRequestTimeout
}
}

msg := new(dns.Msg)
msg.Id = dns.Id()
msg.RecursionDesired = module.DNS.Recursion
msg.Question = make([]dns.Question, 1)
msg.Question[0] = dns.Question{dns.Fqdn(module.DNS.QueryName), qt, qc}
requestStart := time.Now()
response, rtt, err := client.ExchangeContext(ctx, msg, targetIP)
// The rtt value returned from client.Exchange includes only the time to
// exchange messages with the server _after_ the connection is created.
// We compute the connection time as the total time for the operation
// minus the time for the actual request rtt.
probeDNSDurationGaugeVec.WithLabelValues("connect").Set((time.Since(requestStart) - rtt).Seconds())
probeDNSDurationGaugeVec.WithLabelValues("request").Set(rtt.Seconds())
if err != nil {
cause := new(net.OpError)
if errors.As(err, &cause) {
switch {
case cause.Timeout():
level.Error(logger).Log("msg", "DNS request timed out", "err", err)
continue
case cause.Temporary():
level.Error(logger).Log("msg", "DNS request encoutered a temporary error", "err", err)
continue
}
}

level.Info(logger).Log("msg", "Making DNS query", "target", targetIP, "dial_protocol", dialProtocol, "query", module.DNS.QueryName, "type", qt, "class", qc)
timeoutDeadline, _ := ctx.Deadline()
client.Timeout = time.Until(timeoutDeadline)
requestStart := time.Now()
response, rtt, err := client.Exchange(msg, targetIP)
// The rtt value returned from client.Exchange includes only the time to
// exchange messages with the server _after_ the connection is created.
// We compute the connection time as the total time for the operation
// minus the time for the actual request rtt.
probeDNSDurationGaugeVec.WithLabelValues("connect").Set((time.Since(requestStart) - rtt).Seconds())
probeDNSDurationGaugeVec.WithLabelValues("request").Set(rtt.Seconds())
if err != nil {
level.Error(logger).Log("msg", "Error while sending a DNS query", "err", err)
return false
}
level.Info(logger).Log("msg", "Got response", "response", response)
level.Error(logger).Log("msg", "Error while sending a DNS query", "err", err)
// This is not retryable.
return false
}
level.Info(logger).Log("msg", "Got response", "response", response)

probeDNSAnswerRRSGauge.Set(float64(len(response.Answer)))
probeDNSAuthorityRRSGauge.Set(float64(len(response.Ns)))
probeDNSAdditionalRRSGauge.Set(float64(len(response.Extra)))
probeDNSQuerySucceeded.Set(1)
probeDNSAnswerRRSGauge.Set(float64(len(response.Answer)))
probeDNSAuthorityRRSGauge.Set(float64(len(response.Ns)))
probeDNSAdditionalRRSGauge.Set(float64(len(response.Extra)))
probeDNSQuerySucceeded.Set(1)

if qt == dns.TypeSOA {
probeDNSSOAGauge = prometheus.NewGauge(prometheus.GaugeOpts{
Name: "probe_dns_serial",
Help: "Returns the serial number of the zone",
})
registry.MustRegister(probeDNSSOAGauge)
if qt == dns.TypeSOA {
probeDNSSOAGauge = prometheus.NewGauge(prometheus.GaugeOpts{
Name: "probe_dns_serial",
Help: "Returns the serial number of the zone",
})
registry.MustRegister(probeDNSSOAGauge)

for _, a := range response.Answer {
if soa, ok := a.(*dns.SOA); ok {
probeDNSSOAGauge.Set(float64(soa.Serial))
for _, a := range response.Answer {
if soa, ok := a.(*dns.SOA); ok {
probeDNSSOAGauge.Set(float64(soa.Serial))
}
}
}
}

if !validRcode(response.Rcode, module.DNS.ValidRcodes, logger) {
return false
}
level.Info(logger).Log("msg", "Validating Answer RRs")
if !validRRs(&response.Answer, &module.DNS.ValidateAnswer, logger) {
level.Error(logger).Log("msg", "Answer RRs validation failed")
return false
}
level.Info(logger).Log("msg", "Validating Authority RRs")
if !validRRs(&response.Ns, &module.DNS.ValidateAuthority, logger) {
level.Error(logger).Log("msg", "Authority RRs validation failed")
return false
}
level.Info(logger).Log("msg", "Validating Additional RRs")
if !validRRs(&response.Extra, &module.DNS.ValidateAdditional, logger) {
level.Error(logger).Log("msg", "Additional RRs validation failed")
return false
if !validRcode(response.Rcode, module.DNS.ValidRcodes, logger) {
return false
}
level.Info(logger).Log("msg", "Validating Answer RRs")
if !validRRs(&response.Answer, &module.DNS.ValidateAnswer, logger) {
level.Error(logger).Log("msg", "Answer RRs validation failed")
return false
}
level.Info(logger).Log("msg", "Validating Authority RRs")
if !validRRs(&response.Ns, &module.DNS.ValidateAuthority, logger) {
level.Error(logger).Log("msg", "Authority RRs validation failed")
return false
}
level.Info(logger).Log("msg", "Validating Additional RRs")
if !validRRs(&response.Extra, &module.DNS.ValidateAdditional, logger) {
level.Error(logger).Log("msg", "Additional RRs validation failed")
return false
}
return true
}
return true

return false
}
92 changes: 92 additions & 0 deletions prober/dns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"net"
"os"
"runtime"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -474,6 +475,96 @@ func TestServfailDNSResponse(t *testing.T) {
}
}

func TestTimeoutHandling(t *testing.T) {
testcases := []struct {
Probe config.DNSProbe
ExpectedResult bool
}{
{
Probe: config.DNSProbe{
IPProtocol: "ip4",
IPProtocolFallback: false,
QueryName: "example.com",
QueryType: "A",
Retries: 3,
PerRequestTimeout: 100 * time.Millisecond,
},
ExpectedResult: true,
},
{
Probe: config.DNSProbe{
IPProtocol: "ip4",
IPProtocolFallback: false,
QueryName: "example.com",
QueryType: "A",
Retries: 0, // don't retry
PerRequestTimeout: 100 * time.Millisecond, // but use a per-request timeout
},
ExpectedResult: true,
},
{
Probe: config.DNSProbe{
IPProtocol: "ip4",
IPProtocolFallback: false,
QueryName: "example.com",
QueryType: "A",
Retries: 0, // don't retry
PerRequestTimeout: 0, // fallback to context deadline
},
ExpectedResult: true,
},
}

slowHandler := func(n int, sleep time.Duration) func(w dns.ResponseWriter, r *dns.Msg) {
var i atomic.Int32
return func(w dns.ResponseWriter, r *dns.Msg) {
// For the first n requests simply sleep. After that,
// answer normally. This is to simulate a slow
// authoritative server that eventually answers or a
// response packet that gets lost (which is harder to
// simulate correctly).
if i.Add(1) <= int32(n) {
time.Sleep(sleep)
return
}

authoritativeDNSHandler(w, r)
}
}

for _, protocol := range PROTOCOLS {
for _, test := range testcases {
server, addr := startDNSServer(protocol, slowHandler(test.Probe.Retries, 10*test.Probe.PerRequestTimeout))
defer server.Shutdown()

test.Probe.TransportProtocol = protocol
registry := prometheus.NewRegistry()
testCTX, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()

result := ProbeDNS(testCTX, addr.String(), config.Module{Timeout: time.Second, DNS: test.Probe}, registry, log.NewNopLogger())
if result != test.ExpectedResult {
t.Fatalf("Test had unexpected result: %v", result)
}

mfs, err := registry.Gather()
if err != nil {
t.Fatal(err)
}

expectedResults := map[string]float64{
"probe_dns_answer_rrs": 1,
"probe_dns_authority_rrs": 2,
"probe_dns_additional_rrs": 3,
"probe_dns_query_succeeded": 1,
"probe_dns_retries": float64(test.Probe.Retries),
}

checkRegistryResults(expectedResults, mfs, t)
}
}
}

func TestDNSProtocol(t *testing.T) {
if os.Getenv("CI") == "true" {
t.Skip("skipping; CI is failing on ipv6 dns requests")
Expand Down Expand Up @@ -651,6 +742,7 @@ func TestDNSMetrics(t *testing.T) {
"probe_dns_authority_rrs": nil,
"probe_dns_additional_rrs": nil,
"probe_dns_query_succeeded": nil,
"probe_dns_retries": nil,
}

checkMetrics(expectedMetrics, mfs, t)
Expand Down