Skip to content

Commit 0e8917b

Browse files
committed
udp: impl endpoint-independent filtering
1 parent 07ccf30 commit 0e8917b

File tree

4 files changed

+135
-65
lines changed

4 files changed

+135
-65
lines changed

intra/netstack/udp.go

+86-42
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,20 @@ import (
2323
"gvisor.dev/gvisor/pkg/waiter"
2424
)
2525

26-
var errMissingEp = errors.New("not connected to any endpoint")
26+
var (
27+
errMissingEp = errors.New("not connected to any endpoint")
28+
errMissingReq = errors.New("missing forwarder request")
29+
errFilteredOut = errors.New("no eif; filtered out")
30+
)
31+
32+
type DemuxerFn func(dst netip.AddrPort) error
2733

2834
type GUDPConnHandler interface {
2935
// Proxy proxies data between conn (src) and dst.
3036
Proxy(conn *GUDPConn, src, dst netip.AddrPort) bool
31-
// ProxyMux proxies data between conn and multiple destinations.
32-
ProxyMux(conn *GUDPConn, src, dst netip.AddrPort) bool
37+
// ProxyMux proxies data between conn and multiple destinations
38+
// (endpoint-independent mapping).
39+
ProxyMux(conn *GUDPConn, src, dst netip.AddrPort, dmx DemuxerFn) bool
3340
// Error notes the error in connecting src to dst.
3441
Error(conn *GUDPConn, src, dst netip.AddrPort, err error)
3542
// CloseConns closes conns by ids, or all if ids is empty.
@@ -42,10 +49,17 @@ var _ core.UDPConn = (*GUDPConn)(nil)
4249

4350
type GUDPConn struct {
4451
stack *stack.Stack
45-
c *core.Volatile[*gonet.UDPConn] // conn exposes UDP semantics atop endpoint
46-
src netip.AddrPort // local addr (remote addr in netstack)
47-
dst netip.AddrPort // remote addr (local addr in netstack)
48-
req *udp.ForwarderRequest // egress request as UDP
52+
53+
// conn exposes UDP semantics atop endpoint
54+
c *core.Volatile[*gonet.UDPConn]
55+
// local addr (remote addr in netstack)
56+
// ex: 10.111.222.1:20716; same as endpoint.GetRemoteAddress
57+
src netip.AddrPort
58+
// remote addr (local addr in netstack)
59+
// ex: 10.111.222.3:53; same as endpoint.GetLocalAddress
60+
dst netip.AddrPort
61+
62+
req *udp.ForwarderRequest // egress request as UDP
4963

5064
eim bool // endpoint is muxed
5165
eif bool // endpoint is transparent
@@ -85,6 +99,21 @@ func udpForwarder(s *stack.Stack, h GUDPConnHandler) *udp.Forwarder {
8599
log.E("ns: udp: forwarder: nil request")
86100
return
87101
}
102+
103+
// owner app tun ns h
104+
// repr socket packet endpoint socket
105+
// type udp fd gudpconn core.minconn
106+
//
107+
// (src, dst) :1111, :53 :1111, :53 :53, :1111 :9999, :53
108+
//
109+
// write :1111 => :53 :1111, :53 :53 => :1111 :9999 => :53
110+
// \ /
111+
// \ /
112+
// (pipe) \ /
113+
// / \
114+
// / \
115+
// / \
116+
// read :1111 <= :53 :1111, :53 :53 <= :1111 :9999 <= :53
88117
id := req.ID()
89118
// src 10.111.222.1:20716; same as endpoint.GetRemoteAddress
90119
src := remoteAddrPort(id)
@@ -105,10 +134,30 @@ func udpForwarder(s *stack.Stack, h GUDPConnHandler) *udp.Forwarder {
105134
}
106135
}
107136

137+
demux := func(newdst netip.AddrPort) error {
138+
if newdst == dst {
139+
log.D("ns: udp: demuxer: no-op; src(%v) same as dst(%v)", src, newdst)
140+
return nil
141+
}
142+
if !gc.eif {
143+
return errFilteredOut
144+
}
145+
newgc := makeGUDPConn(s, nil /*not a forwarder req*/, src, newdst)
146+
if !settings.SingleThreaded.Load() {
147+
if err := newgc.Establish(); err != nil {
148+
log.E("ns: udp: demuxer: dial: %v; src(%v) dst(%v)", err, src, newdst)
149+
go h.Error(newgc, src, newdst, err)
150+
return err
151+
}
152+
}
153+
go h.Proxy(newgc, src, newdst)
154+
return nil
155+
}
156+
108157
// proxy in a separate gorountine; return immediately
109158
// why? netstack/dispatcher.go:newReadvDispatcher
110159
if gc.eim {
111-
go h.ProxyMux(gc, src, dst)
160+
go h.ProxyMux(gc, src, dst, demux)
112161
} else {
113162
go h.Proxy(gc, src, dst)
114163
}
@@ -124,47 +173,35 @@ func (g *GUDPConn) conn() *gonet.UDPConn {
124173
}
125174

126175
func (g *GUDPConn) StatefulTeardown() (fin bool) {
127-
_ = g.tryConnect() // establish circuit then teardown
128-
_ = g.Close() // then shutdown
129-
return true // always fin
176+
_ = g.Establish() // establish circuit then teardown
177+
_ = g.Close() // then shutdown
178+
return true // always fin
130179
}
131180

132181
func (g *GUDPConn) Establish() error {
133-
if g.eif {
134-
return g.tryBind()
135-
}
136-
return g.tryConnect()
137-
}
138-
139-
func (g *GUDPConn) tryConnect() error {
140-
if g.ok() { // already setup
141-
return nil
142-
}
143-
144-
wq := new(waiter.Queue)
145-
if endpoint, err := g.req.CreateEndpoint(wq); err != nil {
146-
// ex: CONNECT endpoint for [fd66:f83a:c650::1]:15753 => [fd66:f83a:c650::3]:53; err(no route to host)
147-
log.E("ns: udp: connect: endpoint for %v => %v; err(%v)", g.src, g.dst, err)
148-
return e(err)
149-
} else {
150-
g.c.Store(gonet.NewUDPConn(wq, endpoint))
151-
}
152-
return nil
153-
}
154-
155-
func (g *GUDPConn) tryBind() error {
156182
if g.ok() { // already setup
157183
return nil
158184
}
159185

160-
src, proto := addrport2nsaddr(g.src)
161-
// unconnected socket w/ gonet.DialUDP
162-
if conn, err := gonet.DialUDP(g.stack, &src, nil, proto); err != nil {
163-
log.E("ns: udp: bind: endpoint for %v [=> %v]; err(%v)", g.src, g.dst, err)
164-
return err
186+
if g.req == nil {
187+
src, proto := addrport2nsaddr(g.dst) // remote addr is local addr in netstack
188+
dst, _ := addrport2nsaddr(g.src) // local addr is remote addr in netstack
189+
// ingress socket w/ gonet.DialUDP
190+
if conn, err := gonet.DialUDP(g.stack, &src, &dst, proto); err != nil {
191+
log.E("ns: udp: dial: endpoint for %v => %v; err(%v)", g.src, g.dst, err)
192+
return err
193+
} else {
194+
g.c.Store(conn)
195+
}
165196
} else {
166-
// todo: handle the first pkt like in g.req.CreateEndpoint
167-
g.c.Store(conn)
197+
wq := new(waiter.Queue)
198+
if endpoint, err := g.req.CreateEndpoint(wq); err != nil {
199+
// ex: CONNECT endpoint for [fd66:f83a:c650::1]:15753 => [fd66:f83a:c650::3]:53; err(no route to host)
200+
log.E("ns: udp: connect: endpoint for %v => %v; err(%v)", g.src, g.dst, err)
201+
return e(err)
202+
} else {
203+
g.c.Store(gonet.NewUDPConn(wq, endpoint))
204+
}
168205
}
169206
return nil
170207
}
@@ -196,7 +233,14 @@ func (g *GUDPConn) Write(data []byte) (int, error) {
196233
// ep(state 3 / info &{2048 17 {53 10.111.222.3 17711 10.111.222.1} 1 10.111.222.3 1} / stats &{{{1}} {{0}} {{{0}} {{0}} {{0}} {{0}}} {{{0}} {{0}} {{0}}} {{{0}} {{0}}} {{{0}} {{0}} {{0}}}})
197234
// 3: status:datagram-connected / {2048=>proto, 17=>transport, {53=>local-port localip 17711=>remote-port remoteip}=>endpoint-id, 1=>bind-nic-id, ip=>bind-addr, 1=>registered-nic-id}
198235
// g.ep may be nil: log.V("ns: writeFrom: from(%v) / ep(state %v / info %v / stats %v)", addr, g.ep.State(), g.ep.Info(), g.ep.Stats())
199-
return c.Write(data)
236+
if g.eif {
237+
// unexpected except in cases of DNS override;
238+
// forward the packet to the dst as got from the first pkt
239+
log.W("ns: udp: Write(To): unexpected; %s <= %s; sz: %d", g.src, g.dst, len(data))
240+
return c.WriteTo(data, net.UDPAddrFromAddrPort(g.dst))
241+
} else {
242+
return c.Write(data)
243+
}
200244
}
201245
return 0, netError(g, "udp", "write", io.ErrClosedPipe)
202246
}

intra/tcp.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ const (
7373
)
7474

7575
const (
76-
retrytimeout = 1 * time.Minute
76+
retrytimeout = 15 * time.Second
7777
onFlowTimeout = 5 * time.Second
7878
)
7979

intra/udp.go

+36-21
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ package intra
2727

2828
import (
2929
"errors"
30-
"io"
3130
"net"
3231
"net/netip"
3332
"sync"
@@ -176,32 +175,48 @@ func (h *udpHandler) onFlow(localaddr, target netip.AddrPort, realips, domains,
176175
}
177176

178177
// ProxyMux implements netstack.GUDPConnHandler
179-
func (h *udpHandler) ProxyMux(gconn *netstack.GUDPConn, src, dst netip.AddrPort) (ok bool) {
178+
func (h *udpHandler) ProxyMux(gconn *netstack.GUDPConn, src, dst netip.AddrPort, dmx netstack.DemuxerFn) (ok bool) {
180179
defer core.Recover(core.Exit11, "udp.ProxyMux")
181-
return h.proxy(gconn, src, dst, true)
180+
return h.proxy(gconn, src, dst, dmx)
182181
}
183182

184183
// Error implements netstack.GUDPConnHandler.
185184
// Must be called from a goroutine.
186-
func (h *udpHandler) Error(gconn *netstack.GUDPConn, src, dst netip.AddrPort, err error) {
187-
ok := h.proxy(gconn, src, dst, false)
188-
log.I("udp: proxy: %v -> %v; err %v; recovered? %t", src, dst, err, ok)
185+
func (h *udpHandler) Error(gconn *netstack.GUDPConn, src, target netip.AddrPort, err error) {
186+
log.W("udp: proxy: %v -> %v; err %v", src, target, err)
187+
if !src.IsValid() || !target.IsValid() {
188+
return
189+
}
190+
191+
realips, domains, probableDomains, blocklists := undoAlg(h.resolver, target.Addr())
192+
193+
// flow is alg/nat-aware, do not change target or any addrs
194+
res := h.onFlow(src, target, realips, domains, probableDomains, blocklists)
195+
cid, pid, uid := splitCidPidUid(res)
196+
smm := udpSummary(cid, pid, uid, target.Addr())
197+
198+
if h.status.Load() == UDPEND {
199+
err = errUdpEnd
200+
} else if pid == ipn.Block {
201+
err = errUdpFirewalled
202+
}
203+
smm.done(err)
189204
}
190205

191206
// Proxy implements netstack.GUDPConnHandler; thread-safe.
192207
// Must be called from a goroutine.
193208
func (h *udpHandler) Proxy(gconn *netstack.GUDPConn, src, dst netip.AddrPort) (ok bool) {
194209
defer core.Recover(core.Exit11, "udp.Proxy")
195-
return h.proxy(gconn, src, dst, false)
210+
return h.proxy(gconn, src, dst, nil)
196211
}
197212

198213
// proxy connects src to dst over a proxy; thread-safe.
199-
func (h *udpHandler) proxy(gconn *netstack.GUDPConn, src, dst netip.AddrPort, mux bool) (ok bool) {
200-
201-
remote, smm, ct, err := h.Connect(gconn, src, dst, mux) // remote may be nil; smm is never nil
214+
func (h *udpHandler) proxy(gconn *netstack.GUDPConn, src, dst netip.AddrPort, dmx netstack.DemuxerFn) (ok bool) {
215+
mux := dmx != nil
216+
remote, smm, err := h.Connect(gconn, src, dst, dmx) // remote may be nil; smm is never nil
202217

203218
if err != nil {
204-
clos(gconn, remote)
219+
core.Close(gconn, remote)
205220
queueSummary(h.smmch, h.done, smm.done(err)) // smm may be nil
206221
log.W("udp: proxy: mux? %t, unexpected %s -> %s; err: %v", mux, src, dst, err)
207222
// dst addrs no longer tracked in h.Connect: h.conntracker.Untrack(ct.CID)
@@ -217,23 +232,23 @@ func (h *udpHandler) proxy(gconn *netstack.GUDPConn, src, dst netip.AddrPort, mu
217232
cid = smm.ID
218233
}
219234

220-
h.conntracker.Track(ct, gconn, remote)
235+
h.conntracker.Track(cid, gconn, remote)
221236
core.Go("udp.forward: "+cid, func() {
222-
defer h.conntracker.Untrack(ct.CID)
237+
defer h.conntracker.Untrack(cid)
223238
forward(gconn, &rwext{remote}, h.smmch, h.done, smm)
224239
})
225240
return true // ok
226241
}
227242

228243
// Connect connects the proxy server; thread-safe.
229-
func (h *udpHandler) Connect(gconn *netstack.GUDPConn, src, target netip.AddrPort, mux bool) (dst core.UDPConn, smm *SocketSummary, ct core.ConnTuple, err error) {
230-
var px ipn.Proxy = nil
231-
var pc io.Closer = nil
232-
233-
// connect gconn right away, since we assume a duplex-stream from here on
234-
// see: h.Connect -> dnsOverride
235-
if err = gconn.Establish(); err != nil {
236-
log.W("udp: %s gconn connect, mux? %t, err %s => %s", src, target, mux, err)
244+
func (h *udpHandler) Connect(gconn *netstack.GUDPConn, src, target netip.AddrPort, dmx netstack.DemuxerFn) (pc net.Conn, smm *SocketSummary, err error) {
245+
mux := dmx != nil
246+
247+
if !target.IsValid() { // must call h.Bind
248+
err = errUdpUnconnected
249+
} else { // connect gconn right away, since we assume a duplex-stream from here on
250+
// see: h.Connect -> dnsOverride
251+
err = gconn.Establish()
237252
} // err handled after onFlow, so that the listener knows about this gconn/flow
238253

239254
realips, domains, probableDomains, blocklists := undoAlg(h.resolver, target.Addr())

intra/udpmux.go

+12-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818

1919
"github.com/celzero/firestack/intra/core"
2020
"github.com/celzero/firestack/intra/log"
21+
"github.com/celzero/firestack/intra/netstack"
2122
)
2223

2324
// from: github.com/pion/transport/blob/03c807b/udp/conn.go
@@ -61,7 +62,8 @@ type muxer struct {
6162
dxconns chan *demuxconn // never closed
6263
doneCh chan struct{} // stop vending, reading, and routing
6364
once sync.Once
64-
cb func() // muxer.stop() callback (new goroutine)
65+
cb func() // muxer.stop() callback (in a new goroutine)
66+
vnd netstack.DemuxerFn // for new routes in netstack
6567

6668
rmu sync.Mutex // protects routes
6769
routes map[string]*demuxconn // remote addr -> demuxed conn
@@ -249,6 +251,11 @@ func (x *muxer) route(raddr net.Addr) (*demuxconn, error) {
249251
case x.dxconns <- conn:
250252
x.stats.dxcount.Add(1)
251253
x.routes[addr] = conn
254+
if dst, err := addr2netip(raddr); err == nil && dst.IsValid() {
255+
go x.vnd(dst)
256+
} else { // should never happen
257+
log.E("udp: mux: %s route: invalid addr %s; err: %v", x.cid, raddr, err)
258+
}
252259
log.I("udp: mux: %s route: new for %s; stats: %d",
253260
x.cid, raddr, x.stats)
254261
}
@@ -488,3 +495,7 @@ func (e *muxTable) dissociate(id string, src netip.AddrPort) {
488495
defer e.Unlock()
489496
delete(e.t, src)
490497
}
498+
499+
func addr2netip(addr net.Addr) (netip.AddrPort, error) {
500+
return netip.ParseAddrPort(addr.String())
501+
}

0 commit comments

Comments
 (0)