Skip to content

Commit 86dff9c

Browse files
authored
Merge pull request #22 from owenthereal/remove_dns_tcp
Wait for udp/tcp server before shutting it down
2 parents b1258da + 2a49bfb commit 86dff9c

File tree

2 files changed

+20
-11
lines changed

2 files changed

+20
-11
lines changed

dns/server.go

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"net"
7+
"sync"
78
"time"
89

910
"github.com/miekg/dns"
@@ -38,35 +39,43 @@ func (d *dnsServer) Run(ctx context.Context) error {
3839
mux.HandleFunc(tld+".", d.handleDNS)
3940
}
4041

41-
ctx, cancel := context.WithCancel(ctx)
4242
var g run.Group
4343
{
44+
var wg sync.WaitGroup
45+
wg.Add(1)
4446
udp := &dns.Server{
45-
Handler: mux,
46-
Addr: d.cfg.Addr,
47-
Net: "udp",
48-
ReusePort: true,
47+
Handler: mux,
48+
Addr: d.cfg.Addr,
49+
Net: "udp",
4950
}
5051
g.Add(func() error {
52+
wg.Done()
5153
return udp.ListenAndServe()
5254
}, func(err error) {
55+
// Wait for udp server before shutting it down
56+
wg.Wait()
5357
_ = udp.ShutdownContext(ctx)
5458
})
5559
}
5660
{
61+
var wg sync.WaitGroup
62+
wg.Add(1)
5763
tcp := &dns.Server{
58-
Handler: mux,
59-
Addr: d.cfg.Addr,
60-
Net: "tcp",
61-
ReusePort: true,
64+
Handler: mux,
65+
Addr: d.cfg.Addr,
66+
Net: "tcp",
6267
}
6368
g.Add(func() error {
69+
wg.Done()
6470
return tcp.ListenAndServe()
6571
}, func(err error) {
72+
// Wait for tcp server before shutting it down
73+
wg.Wait()
6674
_ = tcp.ShutdownContext(ctx)
6775
})
6876
}
6977
{
78+
ctx, cancel := context.WithCancel(ctx)
7079
g.Add(func() error {
7180
<-ctx.Done()
7281
return ctx.Err()

server/server_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ func Test_Server(t *testing.T) {
176176
}
177177

178178
select {
179-
case <-time.After(20 * time.Second):
179+
case <-time.After(5 * time.Second):
180180
t.Fatal("error wait time out")
181181
case err := <-errch:
182182
if want, got := fmt.Sprintf("host root %s was removed", hostRoot), err.Error(); want != got {
@@ -259,7 +259,7 @@ func Test_Server_Shutdown(t *testing.T) {
259259
}()
260260

261261
select {
262-
case <-time.After(5 * time.Second):
262+
case <-time.After(10 * time.Second):
263263
t.Fatal("error wait time out")
264264
case err := <-errch:
265265
if want, got := c.WantErrMsg, err.Error(); !strings.Contains(got, want) {

0 commit comments

Comments
 (0)