Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 183 additions & 41 deletions src/opnsense/scripts/wireguard/reresolve-dns.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/local/bin/python3
#!/usr/bin/env python3

"""
Copyright (c) 2023 Ad Schellevis <ad@opnsense.org>
Copyright (c) 2023-2026 Ad Schellevis <ad@opnsense.org>
All rights reserved.

Redistribution and use in source and binary forms, with or without
Expand All @@ -27,49 +27,191 @@
"""
# Python implementation to re-resolve dns entries, for reference see:
# https://github.com/WireGuard/wireguard-tools/tree/master/contrib/reresolve-dns

from typing import Tuple, Union, List
import sys
import glob
import os
import time
import subprocess
import logging
from logging.handlers import RotatingFileHandler
import argparse


def create_logger(log_file: str) -> logging.Logger:
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[
RotatingFileHandler(
log_file, encoding="utf-8", maxBytes=10240, backupCount=4
),
logging.StreamHandler(),
],
)
return logging.getLogger()


def runner(cmd: Union[List[str], str]) -> Tuple[bool, str]:
try:
logger.debug("Running command: {}".format(cmd))
child = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8"
)
stdout, stderr = child.communicate(timeout=60)
child.wait(timeout=60)
if child.returncode == 0:
return True, stdout
logger.error(
"Command {} failed with exit code {}:\n{}".format(
cmd, child.returncode, stderr
)
)
except subprocess.TimeoutExpired as exc:
logger.error("Command {} took too long: {}".format(cmd, exc))
except subprocess.SubprocessError as exc:
logger.error("Command {} failed: {}".format(cmd, exc))
return False, None


def get_handshakes() -> dict:
logger.debug("Getting handshakes")
# wg show all latest-handshakes produces one line per peer in for of
# iface pubkey epoch-of-last-handshake

handshakes = {}
ts_now = time.time()

result, content = runner(["/usr/bin/wg", "show", "all", "latest-handshakes"])
if not result:
return handshakes

for line in content.split("\n"):
parts = line.split()
if len(parts) == 3 and parts[2].isdigit():
elapsed_time = ts_now - int(parts[2])
handshakes["%s-%s" % (parts[0], parts[1])] = elapsed_time
logger.info(
"Last handshake for interface {} was {:.2f} seconds ago".format(
parts[0], elapsed_time
)
)
return handshakes


def check_recent_handshakes(threshold: int, conf_file_path: str) -> bool:
successful_run = True
handshakes = get_handshakes()
globs = glob.glob(conf_file_path.rstrip("/") + "/*.conf")
if not globs:
logger.warning(
"It seems that there are no config file candidates in {}".format(
conf_file_path
)
)
return False
for filename in globs:
this_peer = {}
ifname = os.path.basename(filename).split(".")[0]
logger.info("Checking handshake threshold for interface {}".format(ifname))
with open(filename, "r", encoding="utf-8") as fhandle:
for line in fhandle:
if line.startswith("[Peer]"):
this_peer = {"ifname": ifname}
elif line.lower().startswith("publickey"):
this_peer["PublicKey"] = line.split("=", 1)[1].strip()
elif line.lower().startswith("endpoint"):
this_peer["Endpoint"] = line.split("=", 1)[1].strip()

if "Endpoint" in this_peer and "PublicKey" in this_peer:
peer_key = "%(ifname)s-%(PublicKey)s" % this_peer
if handshakes.get(peer_key, 999) > threshold:
logger.info(
"Trying to reset connection to peer {}".format(
this_peer["Endpoint"]
)
)
# skip if there has been a handshake recently
result, _ = runner(
[
"/usr/bin/wg",
"set",
ifname,
"peer",
this_peer["PublicKey"],
"endpoint",
this_peer["Endpoint"],
],
)
if not result:
logger.error(
"Failed to reset peer {} on interface {}".format(
this_peer["Endpoint"], ifname
)
)
successful_run = False
this_peer = {}
return successful_run


if __name__ == "__main__":
default_logfile = "/var/log/{}.log".format(os.path.basename(__file__))
threshold = 135
configdir = "/usr/local/etc/wireguard"

parser = argparse.ArgumentParser(
prog=__file__, description="DNS Watchguard script for Wireguard"
)

parser.add_argument(
"-t",
"--threshold",
type=int,
dest="threshold",
default=None,
required=False,
help="Max seconds allowed before retriggering a wireguard reload, defaults to {} seconds".format(
threshold
),
)

parser.add_argument(
"-c",
"--configdir",
type=str,
dest="configdir",
default=None,
required=False,
help="Path to wireguard configuration directory, defaults {}".format(configdir),
)

parser.add_argument(
"--logfile",
type=str,
dest="logfile",
default=None,
required=False,
help="Path to logfile, defaults to {}".format(default_logfile),
)
args = parser.parse_args()

if args.threshold:
threshold = args.threshold
if args.configdir:
configdir = args.configdir

if args.logfile:
logger = create_logger(args.logfile)
else:
logger = create_logger(default_logfile)

logger.info(
"Running wireguard watchdog with a threshhold of {} seconds".format(threshold)
)

sp = subprocess.run(['/usr/bin/wg', 'show', 'all', 'latest-handshakes'], capture_output=True, text=True)
ts_now = time.time()
handshakes = {}
for line in sp.stdout.split('\n'):
parts = line.split()
if len(parts) == 3 and parts[2].isdigit():
handshakes["%s-%s" % (parts[0], parts[1])] = ts_now - int(parts[2])


for filename in glob.glob('/usr/local/etc/wireguard/*.conf'):
this_peer = {}
ifname = os.path.basename(filename).split('.')[0]
with open(filename, 'r') as fhandle:
for line in fhandle:
if line.startswith('[Peer]'):
this_peer = {'ifname': ifname}
elif line.startswith('PublicKey'):
this_peer['PublicKey'] = line.split('=', 1)[1].strip()
elif line.startswith('Endpoint'):
this_peer['Endpoint'] = line.split('=', 1)[1].strip()

if 'Endpoint' in this_peer and 'PublicKey' in this_peer:
peer_key = "%(ifname)s-%(PublicKey)s" % this_peer
if handshakes.get(peer_key, 999) > 135:
# skip if there has been a handshake recently
subprocess.run(
[
'/usr/bin/wg',
'set',
ifname,
'peer',
this_peer['PublicKey'],
'endpoint',
this_peer['Endpoint']
],
capture_output=True,
text=True
)
this_peer = {}
try:
sys.exit(0) if check_recent_handshakes(threshold, configdir) else sys.exit(1)
except Exception as exc:
logger.critical("Failed to run: {}".format(exc))
sys.exit(1)