diff --git a/pkg/mesh/mesh.go b/pkg/mesh/mesh.go index ee018954..a805c1e8 100644 --- a/pkg/mesh/mesh.go +++ b/pkg/mesh/mesh.go @@ -454,13 +454,18 @@ func (m *Mesh) applyTopology() { return } // Find the old configuration. - oldConfRaw, err := wireguard.ShowConf(link.Attrs().Name) + oldConfDump, err := wireguard.ShowDump(link.Attrs().Name) + if err != nil { + level.Error(m.logger).Log("error", err) + m.errorCounter.WithLabelValues("apply").Inc() + return + } + oldConf, err := wireguard.ParseDump(oldConfDump) if err != nil { level.Error(m.logger).Log("error", err) m.errorCounter.WithLabelValues("apply").Inc() return } - oldConf := wireguard.Parse(oldConfRaw) natEndpoints := discoverNATEndpoints(nodes, peers, oldConf, m.logger) nodes[m.hostname].DiscoveredEndpoints = natEndpoints t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger) @@ -782,17 +787,15 @@ func discoverNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf * } for _, n := range nodes { if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 { - level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint, false)) - // Should check location leader but only available in topology ... or have topology handle that list - // Better check wg latest-handshake - if !n.Endpoint.Equal(peer.Endpoint, false) { + level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint, false), "latest-handshake", peer.LatestHandshake) + if (peer.LatestHandshake != time.Time{}) { natEndpoints[string(n.Key)] = peer.Endpoint } } } for _, p := range peers { if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 { - if !p.Endpoint.Equal(peer.Endpoint, false) { + if (peer.LatestHandshake != time.Time{}) { natEndpoints[string(p.PublicKey)] = peer.Endpoint } } diff --git a/pkg/wireguard/conf.go b/pkg/wireguard/conf.go index 0ce55e34..7cef64d1 100644 --- a/pkg/wireguard/conf.go +++ b/pkg/wireguard/conf.go @@ -17,11 +17,13 @@ package wireguard import ( "bufio" "bytes" + "errors" "fmt" "net" "sort" "strconv" "strings" + "time" "k8s.io/apimachinery/pkg/util/validation" ) @@ -31,6 +33,9 @@ type key string const ( separator = "=" + dumpSeparator = "\t" + dumpNone = "(none)" + dumpOff = "off" interfaceSection section = "Interface" peerSection section = "Peer" listenPortKey key = "ListenPort" @@ -42,6 +47,30 @@ const ( publicKeyKey key = "PublicKey" ) +type dumpInterfaceIndex int + +const ( + dumpInterfacePrivateKeyIndex = iota + dumpInterfacePublicKeyIndex + dumpInterfaceListenPortIndex + dumpInterfaceFWMarkIndex + dumpInterfaceLen +) + +type dumpPeerIndex int + +const ( + dumpPeerPublicKeyIndex = iota + dumpPeerPresharedKeyIndex + dumpPeerEndpointIndex + dumpPeerAllowedIPsIndex + dumpPeerLatestHandshakeIndex + dumpPeerTransferRXIndex + dumpPeerTransferTXIndex + dumpPeerPersistentKeepaliveIndex + dumpPeerLen +) + // Conf represents a WireGuard configuration file. type Conf struct { Interface *Interface @@ -61,6 +90,8 @@ type Peer struct { PersistentKeepalive int PresharedKey []byte PublicKey []byte + // The following fields are part of the runtime information, not the configuration. + LatestHandshake time.Time } // DeduplicateIPs eliminates duplicate allowed IPs. @@ -146,13 +177,11 @@ func (d DNSOrIP) String() string { func Parse(buf []byte) *Conf { var ( active section - ai *net.IPNet kv []string c Conf err error iface *Interface i int - ip, ip4 net.IP k key line, v string peer *Peer @@ -205,49 +234,15 @@ func Parse(buf []byte) *Conf { case peerSection: switch k { case allowedIPsKey: - // Reuse string slice. - kv = strings.Split(v, ",") - for i = range kv { - ip, ai, err = net.ParseCIDR(strings.TrimSpace(kv[i])) - if err != nil { - continue - } - if ip4 = ip.To4(); ip4 != nil { - ip = ip4 - } else { - ip = ip.To16() - } - ai.IP = ip - peer.AllowedIPs = append(peer.AllowedIPs, ai) - } - case endpointKey: - // Reuse string slice. - kv = strings.Split(v, ":") - if len(kv) < 2 { + err = peer.parseAllowedIPs(v) + if err != nil { continue } - port, err = strconv.ParseUint(kv[len(kv)-1], 10, 32) + case endpointKey: + err = peer.parseEndpoint(v) if err != nil { continue } - d := DNSOrIP{} - ip = net.ParseIP(strings.Trim(strings.Join(kv[:len(kv)-1], ":"), "[]")) - if ip == nil { - if len(validation.IsDNS1123Subdomain(kv[0])) != 0 { - continue - } - d.DNS = kv[0] - } else { - if ip4 = ip.To4(); ip4 != nil { - d.IP = ip4 - } else { - d.IP = ip.To16() - } - } - peer.Endpoint = &Endpoint{ - DNSOrIP: d, - Port: uint32(port), - } case persistentKeepaliveKey: i, err = strconv.Atoi(v) if err != nil { @@ -448,3 +443,177 @@ func writeKey(buf *bytes.Buffer, k key) error { _, err = buf.WriteString(" = ") return err } + +var ( + errParseEndpoint = errors.New("could not parse Endpoint") +) + +func (p *Peer) parseEndpoint(v string) error { + var ( + kv []string + err error + ip, ip4 net.IP + port uint64 + ) + kv = strings.Split(v, ":") + if len(kv) < 2 { + return errParseEndpoint + } + port, err = strconv.ParseUint(kv[len(kv)-1], 10, 32) + if err != nil { + return err + } + d := DNSOrIP{} + ip = net.ParseIP(strings.Trim(strings.Join(kv[:len(kv)-1], ":"), "[]")) + if ip == nil { + if len(validation.IsDNS1123Subdomain(kv[0])) != 0 { + return errParseEndpoint + } + d.DNS = kv[0] + } else { + if ip4 = ip.To4(); ip4 != nil { + d.IP = ip4 + } else { + d.IP = ip.To16() + } + } + + p.Endpoint = &Endpoint{ + DNSOrIP: d, + Port: uint32(port), + } + return nil +} + +func (p *Peer) parseAllowedIPs(v string) error { + var ( + ai *net.IPNet + kv []string + err error + i int + ip, ip4 net.IP + ) + + kv = strings.Split(v, ",") + for i = range kv { + ip, ai, err = net.ParseCIDR(strings.TrimSpace(kv[i])) + if err != nil { + return err + } + if ip4 = ip.To4(); ip4 != nil { + ip = ip4 + } else { + ip = ip.To16() + } + ai.IP = ip + p.AllowedIPs = append(p.AllowedIPs, ai) + } + return nil +} + +// ParseDump parses a given WireGuard dump and produces a Conf struct. +func ParseDump(buf []byte) (*Conf, error) { + // from man wg, show section: + // If dump is specified, then several lines are printed; + // the first contains in order separated by tab: private-key, public-key, listen-port, fw‐mark. + // Subsequent lines are printed for each peer and contain in order separated by tab: + // public-key, preshared-key, endpoint, allowed-ips, latest-handshake, transfer-rx, transfer-tx, persistent-keepalive. + var ( + active section + values []string + c Conf + err error + iface *Interface + peer *Peer + port uint64 + sec int64 + pka int + line int + ) + // First line is Interface + active = interfaceSection + s := bufio.NewScanner(bytes.NewBuffer(buf)) + for s.Scan() { + values = strings.Split(s.Text(), dumpSeparator) + + switch active { + case interfaceSection: + if len(values) < dumpInterfaceLen { + return nil, fmt.Errorf("invalid interface line: missing fields (%d < %d)", len(values), dumpInterfaceLen) + } + iface = new(Interface) + for i := range values { + switch i { + case dumpInterfacePrivateKeyIndex: + iface.PrivateKey = []byte(values[i]) + case dumpInterfaceListenPortIndex: + port, err = strconv.ParseUint(values[i], 10, 32) + if err != nil { + return nil, fmt.Errorf("invalid interface line: error parsing listen-port: %w", err) + } + iface.ListenPort = uint32(port) + } + } + c.Interface = iface + // Next lines are Peers + active = peerSection + case peerSection: + if len(values) < dumpPeerLen { + return nil, fmt.Errorf("invalid peer line %d: missing fields (%d < %d)", line, len(values), dumpPeerLen) + } + peer = new(Peer) + + for i := range values { + switch i { + case dumpPeerPublicKeyIndex: + peer.PublicKey = []byte(values[i]) + case dumpPeerPresharedKeyIndex: + if values[i] == dumpNone { + continue + } + peer.PresharedKey = []byte(values[i]) + case dumpPeerEndpointIndex: + if values[i] == dumpNone { + continue + } + err = peer.parseEndpoint(values[i]) + if err != nil { + return nil, fmt.Errorf("invalid peer line %d: error parsing endpoint: %w", line, err) + } + case dumpPeerAllowedIPsIndex: + if values[i] == dumpNone { + continue + } + err = peer.parseAllowedIPs(values[i]) + if err != nil { + return nil, fmt.Errorf("invalid peer line %d: error parsing allowed-ips: %w", line, err) + } + case dumpPeerLatestHandshakeIndex: + if values[i] == "0" { + // Use go zero value, not unix 0 timestamp. + peer.LatestHandshake = time.Time{} + continue + } + sec, err = strconv.ParseInt(values[i], 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid peer line %d: error parsing latest-handshake: %w", line, err) + } + peer.LatestHandshake = time.Unix(sec, 0) + case dumpPeerPersistentKeepaliveIndex: + if values[i] == dumpOff { + continue + } + pka, err = strconv.Atoi(values[i]) + if err != nil { + return nil, fmt.Errorf("invalid peer line %d: error parsing persistent-keepalive: %w", line, err) + } + peer.PersistentKeepalive = pka + } + } + c.Peers = append(c.Peers, peer) + peer = nil + } + line++ + } + return &c, nil +} diff --git a/pkg/wireguard/conf_test.go b/pkg/wireguard/conf_test.go index 81a3ed05..aa3e782d 100644 --- a/pkg/wireguard/conf_test.go +++ b/pkg/wireguard/conf_test.go @@ -17,6 +17,8 @@ package wireguard import ( "net" "testing" + + "github.com/kylelemons/godebug/pretty" ) func TestCompareConf(t *testing.T) { @@ -308,3 +310,47 @@ func TestCompareEndpoint(t *testing.T) { } } } + +func TestCompareDumpConf(t *testing.T) { + for _, tc := range []struct { + name string + d []byte + c []byte + }{ + { + name: "empty", + d: []byte{}, + c: []byte{}, + }, + { + name: "redacted copy from wg output", + d: []byte(`private B7qk8EMlob0nfado0ABM6HulUV607r4yqtBKjhap7S4= 51820 off +key1 (none) 10.254.1.1:51820 100.64.1.0/24,192.168.0.125/32,10.4.0.1/32 1619012801 67048 34952 10 +key2 (none) 10.254.2.1:51820 100.64.4.0/24,10.69.76.55/32,100.64.3.0/24,10.66.25.131/32,10.4.0.2/32 1619013058 1134456 10077852 10`), + c: []byte(`[Interface] + ListenPort = 51820 + PrivateKey = private + + [Peer] + PublicKey = key1 + AllowedIPs = 100.64.1.0/24, 192.168.0.125/32, 10.4.0.1/32 + Endpoint = 10.254.1.1:51820 + PersistentKeepalive = 10 + + [Peer] + PublicKey = key2 + AllowedIPs = 100.64.4.0/24, 10.69.76.55/32, 100.64.3.0/24, 10.66.25.131/32, 10.4.0.2/32 + Endpoint = 10.254.2.1:51820 + PersistentKeepalive = 10`), + }, + } { + + dumpConf, _ := ParseDump(tc.d) + conf := Parse(tc.c) + // Equal will ignore runtime fields and only compare configuration fields. + if !dumpConf.Equal(conf) { + diff := pretty.Compare(dumpConf, conf) + t.Errorf("test case %q: got diff: %v", tc.name, diff) + } + } +} diff --git a/pkg/wireguard/wireguard.go b/pkg/wireguard/wireguard.go index dbbba72d..953eb9f0 100644 --- a/pkg/wireguard/wireguard.go +++ b/pkg/wireguard/wireguard.go @@ -119,3 +119,15 @@ func ShowConf(iface string) ([]byte, error) { } return stdout.Bytes(), nil } + +// ShowDump gets the WireGuard configuration and runtime information for the given interface. +func ShowDump(iface string) ([]byte, error) { + cmd := exec.Command("wg", "show", iface, "dump") + var stderr, stdout bytes.Buffer + cmd.Stderr = &stderr + cmd.Stdout = &stdout + if err := cmd.Run(); err != nil { + return nil, fmt.Errorf("failed to read the WireGuard dump output: %s", stderr.String()) + } + return stdout.Bytes(), nil +}