@@ -28,45 +28,62 @@ type Rule struct {
28
28
29
29
func main () {
30
30
var (
31
- file string
31
+ file string
32
+ verbose bool
33
+ dialTimeout time.Duration
32
34
)
33
35
34
36
flag .StringVar (& file , "f" , "" , "Job to run or leave blank for job.yaml in current directory" )
35
-
37
+ flag .BoolVar (& verbose , "v" , true , "Verbose output for opened and closed connections" )
38
+ flag .DurationVar (& dialTimeout , "t" , time .Millisecond * 1500 , "Dial timeout" )
36
39
flag .Parse ()
37
40
41
+ if len (file ) == 0 {
42
+ fmt .Fprintf (os .Stderr , "usage: mixctl -f rules.yaml\n " )
43
+ os .Exit (1 )
44
+ }
45
+
38
46
set := ForwardingSet {}
39
47
data , err := os .ReadFile (file )
40
48
if err != nil {
41
- log .Fatalf ("error reading file %s %s" , file , err .Error ())
49
+ fmt .Fprintf (os .Stderr , "error reading file %s %s" , file , err .Error ())
50
+ os .Exit (1 )
42
51
}
43
52
if err = yaml .Unmarshal (data , & set ); err != nil {
44
- log .Fatalf ("error parsing file %s %s" , file , err .Error ())
53
+ fmt .Fprintf (os .Stderr , "error parsing file %s %s" , file , err .Error ())
54
+ os .Exit (1 )
55
+ }
56
+
57
+ if len (set .Rules ) == 0 {
58
+ fmt .Fprintf (os .Stderr , "no rules found in file %s" , file )
59
+ os .Exit (1 )
45
60
}
46
61
47
- fmt .Printf ("mixctl by inlets.. \n " )
62
+ fmt .Printf ("Starting mixctl by https:// inlets.dev/ \n \n " )
48
63
49
64
wg := sync.WaitGroup {}
50
65
wg .Add (len (set .Rules ))
51
- for _ , f := range set .Rules {
52
-
53
- r := f
54
- go func (rule * Rule ) {
55
- fmt .Printf ("Forward (%s) from: %s to: %s\n " , rule .Name , rule .From , rule .To )
66
+ for _ , rule := range set .Rules {
67
+ fmt .Printf ("Forward (%s) from: %s to: %s\n " , rule .Name , rule .From , rule .To )
68
+ }
69
+ fmt .Println ()
56
70
57
- if err := forward (rule .Name , rule .From , rule .To ); err != nil {
71
+ for _ , rule := range set .Rules {
72
+ // Copy the value to avoid the loop variable being reused
73
+ r := rule
74
+ go func () {
75
+ if err := forward (r .Name , r .From , r .To , verbose , dialTimeout ); err != nil {
58
76
log .Printf ("error forwarding %s" , err .Error ())
59
77
os .Exit (1 )
60
78
}
61
-
62
79
defer wg .Done ()
63
- }(& r )
80
+ }()
64
81
}
65
- wg .Wait ()
66
82
83
+ wg .Wait ()
67
84
}
68
85
69
- func forward (name , from string , to []string ) error {
86
+ func forward (name , from string , to []string , verbose bool , dialTimeout time. Duration ) error {
70
87
seed := time .Now ().UnixNano ()
71
88
rand .Seed (seed )
72
89
@@ -76,42 +93,92 @@ func forward(name, from string, to []string) error {
76
93
return fmt .Errorf ("error listening on %s %s" , from , err .Error ())
77
94
}
78
95
96
+ defer l .Close ()
97
+
79
98
for {
80
- conn , err := l .Accept ()
99
+ // accept a connection on the local port of the load balancer
100
+ local , err := l .Accept ()
81
101
if err != nil {
82
102
return fmt .Errorf ("error accepting connection %s" , err .Error ())
83
103
}
84
104
105
+ // pick randomly from the list of upstream servers
106
+ // available
85
107
index := rand .Intn (len (to ))
108
+ upstream := to [index ]
86
109
87
- remote , err := net .Dial ("tcp" , to [index ])
88
- if err != nil {
89
- return fmt .Errorf ("error dialing %s %s" , to [index ], err .Error ())
90
- }
110
+ // A separate Goroutine means the loop can accept another
111
+ // incoming connection on the local address
112
+ go connect (local , upstream , from , verbose , dialTimeout )
113
+ }
114
+ }
91
115
92
- go func () {
93
- log .Printf ("[%s] %s => %s" ,
94
- from ,
95
- conn .RemoteAddr ().String (),
96
- remote .RemoteAddr ().String ())
97
- if err := forwardConnection (conn , remote ); err != nil && err .Error () != "done" {
98
- log .Printf ("error forwarding connection %s" , err .Error ())
99
- }
100
- }()
116
+ // connect dials the upstream address, then copies data
117
+ // between it and connection accepted on a local port
118
+ func connect (local net.Conn , upstreamAddr , from string , verbose bool , dialTimeout time.Duration ) {
119
+ defer local .Close ()
120
+
121
+ // If Dial is used on its own, then the timeout can be as long
122
+ // as 2 minutes on MacOS for an unreachable host
123
+ upstream , err := net .DialTimeout ("tcp" , upstreamAddr , dialTimeout )
124
+ if err != nil {
125
+ log .Printf ("error dialing %s %s" , upstreamAddr , err .Error ())
126
+ return
127
+ }
128
+ defer upstream .Close ()
129
+
130
+ if verbose {
131
+ log .Printf ("Connected %s => %s (%s)" ,
132
+ from ,
133
+ upstream .RemoteAddr ().String (),
134
+ local .RemoteAddr ().String ())
135
+ }
136
+
137
+ ctx := context .Background ()
138
+ if err := copy (ctx , local , upstream ); err != nil && err .Error () != "done" {
139
+ log .Printf ("error forwarding connection %s" , err .Error ())
140
+ }
141
+
142
+ if verbose {
143
+ log .Printf ("Closed %s => %s (%s)" ,
144
+ from ,
145
+ upstream .RemoteAddr ().String (),
146
+ local .RemoteAddr ().String ())
101
147
}
102
148
}
103
149
104
- func forwardConnection (from , to net.Conn ) error {
105
- errgrp , _ := errgroup .WithContext (context .Background ())
150
+ // copy copies data between two connections using io.Copy
151
+ // and will exit when either connection is closed or runs
152
+ // into an error
153
+ func copy (ctx context.Context , from , to net.Conn ) error {
154
+
155
+ ctx , cancel := context .WithCancel (ctx )
156
+ errgrp , _ := errgroup .WithContext (ctx )
106
157
errgrp .Go (func () error {
107
158
io .Copy (from , to )
159
+ cancel ()
108
160
109
161
return fmt .Errorf ("done" )
110
162
})
111
163
errgrp .Go (func () error {
112
164
io .Copy (to , from )
165
+ cancel ()
166
+
113
167
return fmt .Errorf ("done" )
114
168
})
169
+ errgrp .Go (func () error {
170
+ <- ctx .Done ()
171
+
172
+ // This closes both ends of the connection as
173
+ // soon as possible.
174
+ from .Close ()
175
+ to .Close ()
176
+ return fmt .Errorf ("done" )
177
+ })
178
+
179
+ if err := errgrp .Wait (); err != nil {
180
+ return err
181
+ }
115
182
116
- return errgrp . Wait ()
183
+ return nil
117
184
}
0 commit comments