Skip to content

Commit

Permalink
ipn/wg: merge amnezia and bepass
Browse files Browse the repository at this point in the history
  • Loading branch information
ignoramous committed Nov 17, 2024
1 parent 757269f commit f87b287
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 80 deletions.
25 changes: 15 additions & 10 deletions intra/ipn/wg/amnezia.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ func (a *Amnezia) String() string {
if a == nil {
return "<nil>"
}
if !a.Set() {
return "<unset>"
}
return fmt.Sprintf("%s: amnezia: jc(%d), jmin(%d), jmax(%d), s1(%d), s2(%d), h1(%d), h2(%d), h3(%d), h4(%d)",
a.id, a.Jc, a.Jmin, a.Jmax, a.S1, a.S2, a.H1, a.H2, a.H3, a.H4)
}
Expand Down Expand Up @@ -113,16 +116,16 @@ func (a *Amnezia) recv(pktptr *[]byte) (ok bool) {
pkt, typ = a.strip(pkt)

switch typ {
case a.H1:
case device.MessageInitiationType, a.H1:
typ = device.MessageInitiationType
binary.LittleEndian.PutUint32(pkt[:h], device.MessageInitiationType)
case a.H2:
case device.MessageResponseType, a.H2:
typ = device.MessageResponseType
binary.LittleEndian.PutUint32(pkt[:h], device.MessageResponseType)
case a.H3:
case device.MessageCookieReplyType, a.H3:
typ = device.MessageCookieReplyType
binary.LittleEndian.PutUint32(pkt[:h], device.MessageCookieReplyType)
case a.H4:
case device.MessageTransportType, a.H4: // must be default?
typ = device.MessageTransportType
binary.LittleEndian.PutUint32(pkt[:h], device.MessageTransportType)
}
Expand Down Expand Up @@ -192,7 +195,9 @@ func (a *Amnezia) instate(pkt []byte) ([]byte, uint32) {
func (a *Amnezia) strip(pkt []byte) ([]byte, uint32) {
size := uint16(len(pkt))
h := uint16(device.MessageTransportOffsetReceiver)
defaultType := binary.LittleEndian.Uint32(pkt[:h])
// assume the correct msg type is in just the first byte:
// github.com/WireGuard/wireguard-go/blob/12269c2761/device/noise-protocol.go#L56
defaultType := uint8(pkt[0])

var discard uint16 = 0
var possibleType uint32 = 0
Expand All @@ -211,13 +216,13 @@ func (a *Amnezia) strip(pkt []byte) ([]byte, uint32) {

if maybeStrip {
hdr := pkt[discard : discard+h]
strippedType := binary.LittleEndian.Uint32(hdr)
if strippedType == possibleType {
return pkt[discard:], strippedType
} // else: sizes match but msg types do not
obsType := binary.LittleEndian.Uint32(hdr)
if obsType == possibleType {
return pkt[discard:], obsType
} // else: msg type mismatch, but size matched
} // else: nothing to discard

return pkt, defaultType
return pkt, uint32(defaultType)
}

func (a *Amnezia) logIfNeeded(dir string, typ uint32, n int) {
Expand Down
53 changes: 7 additions & 46 deletions intra/ipn/wg/wgconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,8 @@ type StdNetBind struct {
connect connector
mh *multihost.MH

reserved []byte // overwrite the 3 wg reserved bytes
overwriteReserve bool
amnezia *Amnezia
floodBa *core.Barrier[int, netip.AddrPort]
amnezia *Amnezia
floodBa *core.Barrier[int, netip.AddrPort]

mu sync.Mutex // protects following fields
ipv4 *net.UDPConn
Expand All @@ -118,18 +116,16 @@ type StdNetBind struct {
}

// TODO: get d, ep, f, rb through an Opts bag?
func NewEndpoint(id string, d connector, ep *multihost.MH, f rwobserver, a *Amnezia, rb [3]byte) *StdNetBind {
func NewEndpoint(id string, d connector, ep *multihost.MH, f rwobserver, a *Amnezia) *StdNetBind {
s := &StdNetBind{
id: id,
connect: d,
mh: ep,
observer: f,
amnezia: a,
reserved: rb[:3], // github.com/bepass-org/warp-plus/blob/19ac233cc6/wiresocks/config.go#L184
floodBa: core.NewKeyedBarrier[int, netip.AddrPort](minFloodInterval),
sendAddr: core.NewZeroVolatile[netip.AddrPort](),
}
s.overwriteReserve = a.Set() || isReservedOverwitten(s.reserved)
return s
}

Expand Down Expand Up @@ -326,15 +322,7 @@ func (s *StdNetBind) makeReceiveFn(uc *net.UDPConn) conn.ReceiveFunc {
extend(uc, wgtimeout)
n, addr, err := uc.ReadFromUDPAddrPort(b)
if err == nil {
if isReservedOverwitten(b) {
if s.amnezia.Set() {
recvOverwritten = s.amnezia.recv(&b)
} else if n > 3 && isWgMsgType(b[0]) && recvOverwritten {
// github.com/bepass-org/warp-plus/blob/19ac233cc6/wireguard/device/receive.go#L138
copy(b[1:4], reservedZeros)
recvOverwritten = true
}
}
recvOverwritten = s.amnezia.recv(&b)
numMsgs++
}

Expand All @@ -344,7 +332,7 @@ func (s *StdNetBind) makeReceiveFn(uc *net.UDPConn) conn.ReceiveFunc {
}

s := fmt.Sprintf("wg: bind: %s recvFrom(%v): %d / ov? %t<=%t / err? %v",
s.id, addr, n, s.overwriteReserve, recvOverwritten, err)
s.id, addr, n, s.amnezia.Set(), recvOverwritten, err)
if err == nil || timedout(err) {
log.V(s)
} else {
Expand Down Expand Up @@ -405,18 +393,9 @@ func (s *StdNetBind) Send(buf [][]byte, peer conn.Endpoint) (err error) {

datalen := len(data) // grab the length before we overwrite it

if s.overwriteReserve {
if s.amnezia.Set() {
overwritten = s.amnezia.send(&data)
} else if datalen > 3 && isWgMsgType(data[0]) {
// overwrite the 3 reserved bytes on non-random packets
// from: github.com/bepass-org/warp-plus/blob/19ac233cc6/wireguard/device/peer.go#L138
copy(data[1:4], s.reserved)
overwritten = true
}
}
overwritten = s.amnezia.send(&data)

if !flooded && (experimentalWg || s.overwriteReserve) {
if !flooded && (experimentalWg || s.amnezia.Set()) {
if datalen == device.MessageInitiationSize {
s.flood(uc, dst, fkHandshake) // was probably a handshake
flooded = true
Expand All @@ -440,24 +419,6 @@ func (s *StdNetBind) Send(buf [][]byte, peer conn.Endpoint) (err error) {
return err
}

// github.com/WireGuard/wireguard-go/blob/12269c2761/device/send.go#L456
// github.com/WireGuard/wireguard-go/blob/12269c2761/device/noise-protocol.go#L56
func isWgMsgType(x byte) bool {
// 1: MsgInitiation, 2: MsgResponse, 3: MsgCookieReply, 4: MsgTransport
// blog.cloudflare.com/warp-technical-challenges/
// Handshakes have to be performed every two minutes to rotate keys making
// them insufficiently persistent. We could have forked the protocol to add
// any number of additional fields, but it is important to us to remain wire
// compatible with other WireGuard clients. Fortunately, WireGuard has a three
// byte block in its header which is not currently used by other clients.
// We decided to put our identifier in this region and still support messages
// from other WireGuard clients (albeit with less reliable routing than we can
// offer).
// Though the open source Cloudflare WARP boring-tun impl does not do so:
// github.com/cloudflare/boringtun/blob/64a2fc7c63/boringtun/src/noise/handshake.rs#L734
return x >= device.MessageInitiationType && x <= device.MessageTransportType
}

// flood c with random-sized, non-sense (unencrypted) packets.
// this is okay to do because wireguard silently drops packets that won't decrypt.
// github.com/WireGuard/wireguard-go/blob/19ac233cc6/wireguard/device/send.go#L96
Expand Down
54 changes: 30 additions & 24 deletions intra/ipn/wgproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ package ipn

import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/binary"
"fmt"
"net"
"net/netip"
Expand Down Expand Up @@ -77,7 +77,6 @@ type wgifopts struct {
peers map[string]device.NoisePublicKey
dns, ep *multihost.MH
mtu int
clientid [3]byte
amnezia *wg.Amnezia
}

Expand All @@ -91,7 +90,6 @@ type wgtun struct {
ingress chan *buffer.View // pipes ep writes to wg
events chan tun.Event // wg specific tun (interface) events
amnezia *wg.Amnezia // amnezia config, if any
clientid [3]byte // client id; applicable only for warp
finalize chan struct{} // close signal for incomingPacket
once sync.Once // closer fn; exec exactly once
preferOffload bool // UDP GRO/GSO offloads
Expand Down Expand Up @@ -334,11 +332,6 @@ func (w *wgproxy) update(id, txt string) bool {
return anew
}

if !bytes.Equal(opts.clientid[:], w.clientid[:]) {
log.D("proxy: wg: !update(%s): clientid %v != %v", w.id, opts.clientid, w.clientid)
return anew
}

if err := w.setRoutes(opts.ifaddrs); err != nil {
log.W("proxy: wg: !update(%s): setRoutes: %v", w.id, err)
return anew
Expand Down Expand Up @@ -430,15 +423,6 @@ func wgIfConfigOf(id string, txtptr *string) (opts wgifopts, err error) {
if opts.mtu, err = strconv.Atoi(v); err != nil {
return
}
case "client_id":
// only for warp: blog.cloudflare.com/warp-technical-challenges
// When we begin a WireGuard session we include our clientid field
// which is provided by our authentication server which has to be
// communicated with to begin a WARP session.
if b, err := base64.StdEncoding.DecodeString(v); err == nil {
n := copy(opts.clientid[:], b)
log.D("proxy: wg: %s ifconfig: clientid(%d) %v", id, n, opts.clientid)
}
case "allowed_ip": // may exist more than once
if err = loadIPNets(&opts.allowed, v); err != nil {
return
Expand Down Expand Up @@ -473,6 +457,31 @@ func wgIfConfigOf(id string, txtptr *string) (opts wgifopts, err error) {
// peer config: carry over public keys
log.D("proxy: wg: %s ifconfig: processing key %q, err? %v", id, k, exx)
pcfg.WriteString(line + "\n")
case "client_id":
// only for warp: blog.cloudflare.com/warp-technical-challenges
// When we begin a WireGuard session we include our clientid field
// which is provided by our authentication server which has to be
// communicated with to begin a WARP session.
// Though the open source Cloudflare WARP boring-tun impl does not do so:
// github.com/cloudflare/boringtun/blob/64a2fc7c63/boringtun/src/noise/handshake.rs#L734
if b, err := base64.StdEncoding.DecodeString(v); err == nil && len(b) == 3 {
// github.com/WireGuard/wireguard-go/blob/12269c2761/device/send.go#L456
// github.com/WireGuard/wireguard-go/blob/12269c2761/device/noise-protocol.go#L56
h1 := append([]byte{device.MessageInitiationType}, b...)
h2 := append([]byte{device.MessageResponseType}, b...)
h3 := append([]byte{device.MessageCookieReplyType}, b...)
h4 := append([]byte{device.MessageTransportType}, b...)
// overwrite the 3 reserved bytes on all packets
// github.com/bepass-org/warp-plus/blob/19ac233cc6/wireguard/device/receive.go#L138
opts.amnezia.H1 = binary.LittleEndian.Uint32(h1)
opts.amnezia.H2 = binary.LittleEndian.Uint32(h2)
opts.amnezia.H3 = binary.LittleEndian.Uint32(h3)
opts.amnezia.H4 = binary.LittleEndian.Uint32(h4)
log.D("proxy: wg: %s ifconfig: clientid(%d) %v", id, len(b), b)
} else {
log.W("proxy: wg: %s ifconfig: clientid(%v) %d == 3?; err: %v",
id, v, len(b), err)
}
case "jc":
// github.com/amnezia-vpn/amneziawg-go/blob/2e3f7d122c/device/uapi.go#L286
jc, _ := strconv.Atoi(v)
Expand Down Expand Up @@ -506,9 +515,7 @@ func wgIfConfigOf(id string, txtptr *string) (opts wgifopts, err error) {
pcfg.WriteString(line + "\n")
}
}
if opts.amnezia.Set() {
log.I("proxy: wg: %s amnezia: %s", id, opts.amnezia)
}
log.D("proxy: wg: %s amnezia: %s", id, opts.amnezia)
*txtptr = pcfg.String()
if err == nil && len(opts.ifaddrs) <= 0 || opts.dns.Len() <= 0 || opts.mtu <= 0 {
err = errProxyConfig
Expand Down Expand Up @@ -574,7 +581,7 @@ func NewWgProxy(id string, ctl protect.Controller, rev netstack.GConnHandler, cf
// todo: use wgtun.serve fn instead of ctl
wgep = wg.NewEndpoint2(id, ctl, opts.ep, wgtun.listener)
} else {
wgep = wg.NewEndpoint(id, wgtun.serve, opts.ep, wgtun.listener, wgtun.amnezia, wgtun.clientid)
wgep = wg.NewEndpoint(id, wgtun.serve, opts.ep, wgtun.listener, wgtun.amnezia)
}

wgdev := device.NewDevice(wgtun, wgep, wglogger(id))
Expand Down Expand Up @@ -648,7 +655,6 @@ func makeWgTun(id, cfg string, ctl protect.Controller, rev netstack.GConnHandler
rt: x.NewIpTree(), // must be set to allowedaddrs
ba: core.NewBarrier[[]netip.Addr](wgbarrierttl),
amnezia: ifopts.amnezia,
clientid: ifopts.clientid,
status: core.NewVolatile(TUP),
preferOffload: preferOffload(id),
refreshBa: core.NewBarrier[bool](2 * time.Minute),
Expand Down Expand Up @@ -678,8 +684,8 @@ func makeWgTun(id, cfg string, ctl protect.Controller, rev netstack.GConnHandler
t.events <- tun.EventUp

if4, if6 := netstack.StackAddrs(s, wgnic)
log.I("proxy: wg: %s tun: created; dns[%s]; dst[%s]; mtu[%d]; ifaddrs[%v / %v]; clientid[%v]; amnezia[%t]",
t.id, ifopts.dns, ifopts.ep, tunmtu, if4, if6, ifopts.clientid, ifopts.amnezia.Set())
log.I("proxy: wg: %s tun: created; dns[%s]; dst[%s]; mtu[%d]; ifaddrs[%v / %v]; amnezia[%t]",
t.id, ifopts.dns, ifopts.ep, tunmtu, if4, if6, ifopts.amnezia.Set())

return t, nil
}
Expand Down

2 comments on commit f87b287

@ignoramous
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ignoramous
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.