Skip to content

Commit

Permalink
Use LatestHandshake to validate endpoint (#149)
Browse files Browse the repository at this point in the history
* wireguard: `wg show iface dump` reader and parser

* mesh: use LatestHandshake to validate NAT Endpoints

* add skip on error

* switch to loop parsing

So the stop on error pattern can be used

* Add error handling to ParseDump
  • Loading branch information
JulienVdG authored Jul 6, 2021
1 parent 0733c83 commit e12b502
Show file tree
Hide file tree
Showing 4 changed files with 277 additions and 47 deletions.
17 changes: 10 additions & 7 deletions pkg/mesh/mesh.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,13 +454,18 @@ func (m *Mesh) applyTopology() {
return
}
// Find the old configuration.
oldConfRaw, err := wireguard.ShowConf(link.Attrs().Name)
oldConfDump, err := wireguard.ShowDump(link.Attrs().Name)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
oldConf, err := wireguard.ParseDump(oldConfDump)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
oldConf := wireguard.Parse(oldConfRaw)
natEndpoints := discoverNATEndpoints(nodes, peers, oldConf, m.logger)
nodes[m.hostname].DiscoveredEndpoints = natEndpoints
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger)
Expand Down Expand Up @@ -782,17 +787,15 @@ func discoverNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *
}
for _, n := range nodes {
if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 {
level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint, false))
// Should check location leader but only available in topology ... or have topology handle that list
// Better check wg latest-handshake
if !n.Endpoint.Equal(peer.Endpoint, false) {
level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint, false), "latest-handshake", peer.LatestHandshake)
if (peer.LatestHandshake != time.Time{}) {
natEndpoints[string(n.Key)] = peer.Endpoint
}
}
}
for _, p := range peers {
if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 {
if !p.Endpoint.Equal(peer.Endpoint, false) {
if (peer.LatestHandshake != time.Time{}) {
natEndpoints[string(p.PublicKey)] = peer.Endpoint
}
}
Expand Down
249 changes: 209 additions & 40 deletions pkg/wireguard/conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ package wireguard
import (
"bufio"
"bytes"
"errors"
"fmt"
"net"
"sort"
"strconv"
"strings"
"time"

"k8s.io/apimachinery/pkg/util/validation"
)
Expand All @@ -31,6 +33,9 @@ type key string

const (
separator = "="
dumpSeparator = "\t"
dumpNone = "(none)"
dumpOff = "off"
interfaceSection section = "Interface"
peerSection section = "Peer"
listenPortKey key = "ListenPort"
Expand All @@ -42,6 +47,30 @@ const (
publicKeyKey key = "PublicKey"
)

type dumpInterfaceIndex int

const (
dumpInterfacePrivateKeyIndex = iota
dumpInterfacePublicKeyIndex
dumpInterfaceListenPortIndex
dumpInterfaceFWMarkIndex
dumpInterfaceLen
)

type dumpPeerIndex int

const (
dumpPeerPublicKeyIndex = iota
dumpPeerPresharedKeyIndex
dumpPeerEndpointIndex
dumpPeerAllowedIPsIndex
dumpPeerLatestHandshakeIndex
dumpPeerTransferRXIndex
dumpPeerTransferTXIndex
dumpPeerPersistentKeepaliveIndex
dumpPeerLen
)

// Conf represents a WireGuard configuration file.
type Conf struct {
Interface *Interface
Expand All @@ -61,6 +90,8 @@ type Peer struct {
PersistentKeepalive int
PresharedKey []byte
PublicKey []byte
// The following fields are part of the runtime information, not the configuration.
LatestHandshake time.Time
}

// DeduplicateIPs eliminates duplicate allowed IPs.
Expand Down Expand Up @@ -146,13 +177,11 @@ func (d DNSOrIP) String() string {
func Parse(buf []byte) *Conf {
var (
active section
ai *net.IPNet
kv []string
c Conf
err error
iface *Interface
i int
ip, ip4 net.IP
k key
line, v string
peer *Peer
Expand Down Expand Up @@ -205,49 +234,15 @@ func Parse(buf []byte) *Conf {
case peerSection:
switch k {
case allowedIPsKey:
// Reuse string slice.
kv = strings.Split(v, ",")
for i = range kv {
ip, ai, err = net.ParseCIDR(strings.TrimSpace(kv[i]))
if err != nil {
continue
}
if ip4 = ip.To4(); ip4 != nil {
ip = ip4
} else {
ip = ip.To16()
}
ai.IP = ip
peer.AllowedIPs = append(peer.AllowedIPs, ai)
}
case endpointKey:
// Reuse string slice.
kv = strings.Split(v, ":")
if len(kv) < 2 {
err = peer.parseAllowedIPs(v)
if err != nil {
continue
}
port, err = strconv.ParseUint(kv[len(kv)-1], 10, 32)
case endpointKey:
err = peer.parseEndpoint(v)
if err != nil {
continue
}
d := DNSOrIP{}
ip = net.ParseIP(strings.Trim(strings.Join(kv[:len(kv)-1], ":"), "[]"))
if ip == nil {
if len(validation.IsDNS1123Subdomain(kv[0])) != 0 {
continue
}
d.DNS = kv[0]
} else {
if ip4 = ip.To4(); ip4 != nil {
d.IP = ip4
} else {
d.IP = ip.To16()
}
}
peer.Endpoint = &Endpoint{
DNSOrIP: d,
Port: uint32(port),
}
case persistentKeepaliveKey:
i, err = strconv.Atoi(v)
if err != nil {
Expand Down Expand Up @@ -448,3 +443,177 @@ func writeKey(buf *bytes.Buffer, k key) error {
_, err = buf.WriteString(" = ")
return err
}

var (
errParseEndpoint = errors.New("could not parse Endpoint")
)

func (p *Peer) parseEndpoint(v string) error {
var (
kv []string
err error
ip, ip4 net.IP
port uint64
)
kv = strings.Split(v, ":")
if len(kv) < 2 {
return errParseEndpoint
}
port, err = strconv.ParseUint(kv[len(kv)-1], 10, 32)
if err != nil {
return err
}
d := DNSOrIP{}
ip = net.ParseIP(strings.Trim(strings.Join(kv[:len(kv)-1], ":"), "[]"))
if ip == nil {
if len(validation.IsDNS1123Subdomain(kv[0])) != 0 {
return errParseEndpoint
}
d.DNS = kv[0]
} else {
if ip4 = ip.To4(); ip4 != nil {
d.IP = ip4
} else {
d.IP = ip.To16()
}
}

p.Endpoint = &Endpoint{
DNSOrIP: d,
Port: uint32(port),
}
return nil
}

func (p *Peer) parseAllowedIPs(v string) error {
var (
ai *net.IPNet
kv []string
err error
i int
ip, ip4 net.IP
)

kv = strings.Split(v, ",")
for i = range kv {
ip, ai, err = net.ParseCIDR(strings.TrimSpace(kv[i]))
if err != nil {
return err
}
if ip4 = ip.To4(); ip4 != nil {
ip = ip4
} else {
ip = ip.To16()
}
ai.IP = ip
p.AllowedIPs = append(p.AllowedIPs, ai)
}
return nil
}

// ParseDump parses a given WireGuard dump and produces a Conf struct.
func ParseDump(buf []byte) (*Conf, error) {
// from man wg, show section:
// If dump is specified, then several lines are printed;
// the first contains in order separated by tab: private-key, public-key, listen-port, fw‐mark.
// Subsequent lines are printed for each peer and contain in order separated by tab:
// public-key, preshared-key, endpoint, allowed-ips, latest-handshake, transfer-rx, transfer-tx, persistent-keepalive.
var (
active section
values []string
c Conf
err error
iface *Interface
peer *Peer
port uint64
sec int64
pka int
line int
)
// First line is Interface
active = interfaceSection
s := bufio.NewScanner(bytes.NewBuffer(buf))
for s.Scan() {
values = strings.Split(s.Text(), dumpSeparator)

switch active {
case interfaceSection:
if len(values) < dumpInterfaceLen {
return nil, fmt.Errorf("invalid interface line: missing fields (%d < %d)", len(values), dumpInterfaceLen)
}
iface = new(Interface)
for i := range values {
switch i {
case dumpInterfacePrivateKeyIndex:
iface.PrivateKey = []byte(values[i])
case dumpInterfaceListenPortIndex:
port, err = strconv.ParseUint(values[i], 10, 32)
if err != nil {
return nil, fmt.Errorf("invalid interface line: error parsing listen-port: %w", err)
}
iface.ListenPort = uint32(port)
}
}
c.Interface = iface
// Next lines are Peers
active = peerSection
case peerSection:
if len(values) < dumpPeerLen {
return nil, fmt.Errorf("invalid peer line %d: missing fields (%d < %d)", line, len(values), dumpPeerLen)
}
peer = new(Peer)

for i := range values {
switch i {
case dumpPeerPublicKeyIndex:
peer.PublicKey = []byte(values[i])
case dumpPeerPresharedKeyIndex:
if values[i] == dumpNone {
continue
}
peer.PresharedKey = []byte(values[i])
case dumpPeerEndpointIndex:
if values[i] == dumpNone {
continue
}
err = peer.parseEndpoint(values[i])
if err != nil {
return nil, fmt.Errorf("invalid peer line %d: error parsing endpoint: %w", line, err)
}
case dumpPeerAllowedIPsIndex:
if values[i] == dumpNone {
continue
}
err = peer.parseAllowedIPs(values[i])
if err != nil {
return nil, fmt.Errorf("invalid peer line %d: error parsing allowed-ips: %w", line, err)
}
case dumpPeerLatestHandshakeIndex:
if values[i] == "0" {
// Use go zero value, not unix 0 timestamp.
peer.LatestHandshake = time.Time{}
continue
}
sec, err = strconv.ParseInt(values[i], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid peer line %d: error parsing latest-handshake: %w", line, err)
}
peer.LatestHandshake = time.Unix(sec, 0)
case dumpPeerPersistentKeepaliveIndex:
if values[i] == dumpOff {
continue
}
pka, err = strconv.Atoi(values[i])
if err != nil {
return nil, fmt.Errorf("invalid peer line %d: error parsing persistent-keepalive: %w", line, err)
}
peer.PersistentKeepalive = pka
}
}
c.Peers = append(c.Peers, peer)
peer = nil
}
line++
}
return &c, nil
}
Loading

0 comments on commit e12b502

Please sign in to comment.