Skip to content

Commit 3bbd128

Browse files
committedSep 11, 2022
Add a shorter timeout and close connections earlier
Signed-off-by: Alex Ellis (OpenFaaS Ltd) <alexellis2@gmail.com>
1 parent 7ff9d48 commit 3bbd128

File tree

3 files changed

+108
-32
lines changed

3 files changed

+108
-32
lines changed
 

‎README.md

+4
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ To make the upstream address listen on all interfaces, use `0.0.0.0` instead of
4343

4444
The port for the from and to addresses do not need to match.
4545

46+
See also:
47+
* `-t` - specify the dial timeout for an upstream host in the "to" field of the config file.
48+
* `-v` - verbose logging - set to false to turn off logs of connections established and closed.
49+
4650
## License
4751

4852
This software is licensed MIT.

‎main.go

+99-32
Original file line numberDiff line numberDiff line change
@@ -28,45 +28,62 @@ type Rule struct {
2828

2929
func main() {
3030
var (
31-
file string
31+
file string
32+
verbose bool
33+
dialTimeout time.Duration
3234
)
3335

3436
flag.StringVar(&file, "f", "", "Job to run or leave blank for job.yaml in current directory")
35-
37+
flag.BoolVar(&verbose, "v", true, "Verbose output for opened and closed connections")
38+
flag.DurationVar(&dialTimeout, "t", time.Millisecond*1500, "Dial timeout")
3639
flag.Parse()
3740

41+
if len(file) == 0 {
42+
fmt.Fprintf(os.Stderr, "usage: mixctl -f rules.yaml\n")
43+
os.Exit(1)
44+
}
45+
3846
set := ForwardingSet{}
3947
data, err := os.ReadFile(file)
4048
if err != nil {
41-
log.Fatalf("error reading file %s %s", file, err.Error())
49+
fmt.Fprintf(os.Stderr, "error reading file %s %s", file, err.Error())
50+
os.Exit(1)
4251
}
4352
if err = yaml.Unmarshal(data, &set); err != nil {
44-
log.Fatalf("error parsing file %s %s", file, err.Error())
53+
fmt.Fprintf(os.Stderr, "error parsing file %s %s", file, err.Error())
54+
os.Exit(1)
55+
}
56+
57+
if len(set.Rules) == 0 {
58+
fmt.Fprintf(os.Stderr, "no rules found in file %s", file)
59+
os.Exit(1)
4560
}
4661

47-
fmt.Printf("mixctl by inlets..\n")
62+
fmt.Printf("Starting mixctl by https://inlets.dev/\n\n")
4863

4964
wg := sync.WaitGroup{}
5065
wg.Add(len(set.Rules))
51-
for _, f := range set.Rules {
52-
53-
r := f
54-
go func(rule *Rule) {
55-
fmt.Printf("Forward (%s) from: %s to: %s\n", rule.Name, rule.From, rule.To)
66+
for _, rule := range set.Rules {
67+
fmt.Printf("Forward (%s) from: %s to: %s\n", rule.Name, rule.From, rule.To)
68+
}
69+
fmt.Println()
5670

57-
if err := forward(rule.Name, rule.From, rule.To); err != nil {
71+
for _, rule := range set.Rules {
72+
// Copy the value to avoid the loop variable being reused
73+
r := rule
74+
go func() {
75+
if err := forward(r.Name, r.From, r.To, verbose, dialTimeout); err != nil {
5876
log.Printf("error forwarding %s", err.Error())
5977
os.Exit(1)
6078
}
61-
6279
defer wg.Done()
63-
}(&r)
80+
}()
6481
}
65-
wg.Wait()
6682

83+
wg.Wait()
6784
}
6885

69-
func forward(name, from string, to []string) error {
86+
func forward(name, from string, to []string, verbose bool, dialTimeout time.Duration) error {
7087
seed := time.Now().UnixNano()
7188
rand.Seed(seed)
7289

@@ -76,42 +93,92 @@ func forward(name, from string, to []string) error {
7693
return fmt.Errorf("error listening on %s %s", from, err.Error())
7794
}
7895

96+
defer l.Close()
97+
7998
for {
80-
conn, err := l.Accept()
99+
// accept a connection on the local port of the load balancer
100+
local, err := l.Accept()
81101
if err != nil {
82102
return fmt.Errorf("error accepting connection %s", err.Error())
83103
}
84104

105+
// pick randomly from the list of upstream servers
106+
// available
85107
index := rand.Intn(len(to))
108+
upstream := to[index]
86109

87-
remote, err := net.Dial("tcp", to[index])
88-
if err != nil {
89-
return fmt.Errorf("error dialing %s %s", to[index], err.Error())
90-
}
110+
// A separate Goroutine means the loop can accept another
111+
// incoming connection on the local address
112+
go connect(local, upstream, from, verbose, dialTimeout)
113+
}
114+
}
91115

92-
go func() {
93-
log.Printf("[%s] %s => %s",
94-
from,
95-
conn.RemoteAddr().String(),
96-
remote.RemoteAddr().String())
97-
if err := forwardConnection(conn, remote); err != nil && err.Error() != "done" {
98-
log.Printf("error forwarding connection %s", err.Error())
99-
}
100-
}()
116+
// connect dials the upstream address, then copies data
117+
// between it and connection accepted on a local port
118+
func connect(local net.Conn, upstreamAddr, from string, verbose bool, dialTimeout time.Duration) {
119+
defer local.Close()
120+
121+
// If Dial is used on its own, then the timeout can be as long
122+
// as 2 minutes on MacOS for an unreachable host
123+
upstream, err := net.DialTimeout("tcp", upstreamAddr, dialTimeout)
124+
if err != nil {
125+
log.Printf("error dialing %s %s", upstreamAddr, err.Error())
126+
return
127+
}
128+
defer upstream.Close()
129+
130+
if verbose {
131+
log.Printf("Connected %s => %s (%s)",
132+
from,
133+
upstream.RemoteAddr().String(),
134+
local.RemoteAddr().String())
135+
}
136+
137+
ctx := context.Background()
138+
if err := copy(ctx, local, upstream); err != nil && err.Error() != "done" {
139+
log.Printf("error forwarding connection %s", err.Error())
140+
}
141+
142+
if verbose {
143+
log.Printf("Closed %s => %s (%s)",
144+
from,
145+
upstream.RemoteAddr().String(),
146+
local.RemoteAddr().String())
101147
}
102148
}
103149

104-
func forwardConnection(from, to net.Conn) error {
105-
errgrp, _ := errgroup.WithContext(context.Background())
150+
// copy copies data between two connections using io.Copy
151+
// and will exit when either connection is closed or runs
152+
// into an error
153+
func copy(ctx context.Context, from, to net.Conn) error {
154+
155+
ctx, cancel := context.WithCancel(ctx)
156+
errgrp, _ := errgroup.WithContext(ctx)
106157
errgrp.Go(func() error {
107158
io.Copy(from, to)
159+
cancel()
108160

109161
return fmt.Errorf("done")
110162
})
111163
errgrp.Go(func() error {
112164
io.Copy(to, from)
165+
cancel()
166+
113167
return fmt.Errorf("done")
114168
})
169+
errgrp.Go(func() error {
170+
<-ctx.Done()
171+
172+
// This closes both ends of the connection as
173+
// soon as possible.
174+
from.Close()
175+
to.Close()
176+
return fmt.Errorf("done")
177+
})
178+
179+
if err := errgrp.Wait(); err != nil {
180+
return err
181+
}
115182

116-
return errgrp.Wait()
183+
return nil
117184
}

‎rules.example.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,8 @@ rules:
1414
- 192.168.1.19:22
1515
- 192.168.1.21:22
1616
- 192.168.1.20:22
17+
18+
- name: remap-local-ssh-port
19+
from: 127.0.0.1:2222
20+
to:
21+
- 127.0.0.1:22

0 commit comments

Comments
 (0)
Please sign in to comment.