Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions infrastructure/containers/vpn/makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
install:
pip install -e .
1 change: 1 addition & 0 deletions infrastructure/containers/vpn/rootski_vpn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Generate the server config file for the Rootski Vpn."""
217 changes: 217 additions & 0 deletions infrastructure/containers/vpn/rootski_vpn/generate_server_conf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
"""
Standalone shell script for generating the secrets and configuration file for the wireguard server.

This script generates the necessary files (wg0.conf, server.key, and server.pub) to run a wireguard server
on a lightsail instance. This file is used by wireguard_vpn.py in the
rootski/infrastructure/iac/aws-cdk/backend/rootski_backend_cdk/lightsail_vpn folder.
"""

import json
from pathlib import Path
from textwrap import dedent
from typing import Dict, List

import boto3


def main():
"""Genereate files for the server configuration, server private key, and server public key."""
key_pairs: List[dict] = fetch_key_pairs_from_ssm(
key_pairs_ssm_prefix="/rootski/wireguard-vpn/key-pair-for-ip", aws_region="us-west-2"
)
generate_wireguard_server_configuration_file(
key_pairs=key_pairs,
server_conf_fpath=Path("/etc/wireguard/wg0.conf"),
)
generate_wireguard_server_private_key_file(
key_pairs=key_pairs,
server_key_fpath=Path("/etc/wireguard/server.key"),
)
generate_wireguard_server_public_key_file(
key_pairs=key_pairs,
server_pub_fpath=Path("/etc/wireguard/server.pub"),
)


def fetch_key_pairs_from_ssm(
key_pairs_ssm_prefix: str,
aws_region: str,
) -> List[dict]:
"""
Retrieve the deserialized VPN key-pair data from ssm.

:param key_pairs_ssm_prefix: The hierarchy for the parameter, which is the parameter name except for
the last part of the parameter. See :py:func:``fetch_all_ssm_parameters_with_prefix``
:param aws_region: The AWS region where the key-pairs are stored in ssm
:return: The deserialized VPN key-pair data from ssm. See footski_vpn.wireguard_keygen_utils
"""
key_pair_params: List[dict] = fetch_all_ssm_parameters_with_prefix(
path_prefix=key_pairs_ssm_prefix,
aws_region=aws_region,
)
key_pairs: List[dict] = deserialize_key_pair_ssm_parameters(key_pair_parameters=key_pair_params)

return key_pairs


def generate_wireguard_server_configuration_file(
key_pairs: List[dict],
server_conf_fpath: Path = Path("/etc/wireguard/wg0.conf"),
):
"""
Generate the wireguard server configruation file.

:param key_pairs: List of key-pair data
:param server_conf_fpath: Path object representing the filepath to write the server configuration to
"""
key_pairs_sorted_by_ip_address: List[dict] = sort_key_pairs_by_ip_address(key_pairs=key_pairs)

# generate the text for the [Interface] section
server_key_pair = key_pairs_sorted_by_ip_address[0]
interface_section = generate_wireguard_interface(server_private_key=server_key_pair["private_key"])

# generate the text for each [Peer] section
peer_sections = [
generate_wireguard_peer(
peer_public_key=key_pair["public_key"],
ip_address=key_pair["private_ip_address_on_vpn_network"],
owner_username=key_pair["owner_name"],
)
for key_pair in key_pairs_sorted_by_ip_address[1:]
]

# combine the sections into a single string
sections: List[str] = [interface_section, *peer_sections]
server_conf_contents: str = "\n\n".join(sections)

# write the contents to the server configuration file on disk
server_conf_fpath.write_text(server_conf_contents, encoding="utf-8")


def generate_wireguard_server_private_key_file(
key_pairs: List[dict],
server_key_fpath: Path = Path("/etc/wireguard/server.key"),
):
"""
Generate the wireguard server.key file.

:param key_pairs: List of key-pair data
:param server_conf_fpath: Path object representing the filepath to write the server.key file to
"""
server_key_fpath.write_text(key_pairs[0]["private_key"], encoding="utf8")


def generate_wireguard_server_public_key_file(
key_pairs: List[dict],
server_pub_fpath: Path = Path("/etc/wireguard/server.pub"),
):
"""
Generate the wireguard server.pub file.

:param key_pairs: List of key-pair data
:param server_conf_fpath: Path object representing the filepath to write the server.pub file to
"""
server_pub_fpath.write_text(key_pairs[0]["public_key"], encoding="utf8")


def generate_wireguard_interface(server_private_key: str) -> str:
"""
Generate the interface for a wireguard server configuration file wg0.conf.

:param server_private_key: The private key of the wireguard VPN server
:return interface: A string representing the [Interface] section of the wireguard configuration file
"""
interface = dedent(
f"""\
[Interface]
Address = 10.0.0.1/24
ListenPort = 51820
PrivateKey = {server_private_key}
PostUp = iptables -t nat -A POSTROUTING -o eth0 -j MASQUERADE
PostDown = iptables -t nat -D POSTROUTING -o eth0 -j MASQUERADE
"""
)

return interface


def generate_wireguard_peer(peer_public_key: str, ip_address: str, owner_username: str) -> str:
"""
Generate a peer for a wireguard server configuration file wg0.conf.

:param peer_public_key: The public key for a peer's rsa key-pair
:return peer: A string representing a [Peer] section of the wireguard configuration file
"""
peer = dedent(
f"""\
[Peer]
# Username = {owner_username}
PublicKey = {peer_public_key}
AllowedIPs = {ip_address}/32
"""
)

return peer


def fetch_all_ssm_parameters_with_prefix(path_prefix: str, aws_region: str) -> List[dict]:
"""
Retrieve all key-pair aws ssm pararmters using a given prefix.

The aws parameters are retrieved using the boto3 SSM client, and the structure of the key_pair_parameters are found in the following URL.
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ssm.html#SSM.Client.get_parameters_by_path

:param path_prefix: The hierarchy for the parameter, which is the parameter name except for
the last part of the parameter
:param aws_region: The aws region where the key-pair data is stored in ssm
:return: The aws key-pair ssm paramters
"""
# prepare an iterator to fetch the key pairs in batches
ssm_client = boto3.client("ssm", region_name=aws_region)
paginator = ssm_client.get_paginator("get_parameters_by_path")
response_iterator = paginator.paginate(Path=path_prefix)

# fetch all pages of parameters from SSM
pages: List[str, List[dict]] = [page["Parameters"] for page in response_iterator]
# flatten the list of lists to a single list of parameter dicts
parameters: List[dict] = sum(pages, [])

return parameters


def deserialize_key_pair_ssm_parameters(key_pair_parameters: List[str]) -> List[Dict[str, str]]:
"""
Deserialize the aws key-pair ssm parameters.

:param key_pair_parameters: The key-pair aws pararmeters.
See :py:func:``fetch_all_ssm_parameters_with_prefix``
:return: The VPN key-pair data. See rootski_vpn.wireguard_keygen_utils.
"""
deserialized_key_pairs: List[Dict[str, str]] = [
json.loads(ssm_parameter["Value"]) for ssm_parameter in key_pair_parameters
]

return deserialized_key_pairs


def sort_key_pairs_by_ip_address(key_pairs: List[Dict[str, str]]) -> List[Dict[str, str]]:
"""
Sort the key pairs by IP address so the keys are sequentially ordered.

:param key_pairs: The VPN key-pair data to be sorted by ip-address
:return sorted_key_pairs: The sorted key-pair data sequentially organized
"""

def by_ip_address(key_pair_data: dict) -> List[int]:
"""Return the IP address of a key pair in a form that is sortable."""
numerical_representation_of_ip: List[int] = [
int(bit) for bit in key_pair_data["private_ip_address_on_vpn_network"].split(".")
]
return numerical_representation_of_ip

sorted_key_pairs: List[Dict[str, str]] = sorted(key_pairs, key=by_ip_address)
return sorted_key_pairs


if __name__ == "__main__":
main()
61 changes: 61 additions & 0 deletions infrastructure/containers/vpn/rootski_vpn/store_keys_on_aws.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
Generates a number of key-pair data and push them to AWS SSM.

This file will use the file wireguard_keygen_utils.py to generate the wireguard key-pair data
that will be pushed to aws ssm. These wireguard key-pairs will be used to set up the wireguard
vpn running on a lightsail instance. The keys created using this script will be used by
generate_server_conf.py
"""

import json
from typing import Dict, List

import boto3
from mypy_boto3_ssm import SSMClient
from wireguard_keygen_utils import generate_n_keypairs


def store_key_pair_on_aws(key_pair_identifier: str, keypair_data: dict, ssm_client: SSMClient):
"""
Store a keypair on AWS SSM.

:raises SSM.Client.exceptions.ParameterAlreadyExists: if the keypair already exists

:param key_pair_identifier: name that uniquely identifies
the keypair saved in parameter store. It is the last part
of the SSM parameter path.
:param keypair_data: dictionary containing the keypair data.
See :py:class:`wireguard_keygen_utils.VpnKeyPairData` for
an explanation of the contents.
"""
ssm_client.put_parameter(
Name=f"/rootski/wireguard-vpn/key-pair-for-ip/{key_pair_identifier}",
Description="WireGuard Key Data",
Value=json.dumps(keypair_data, indent=4),
Type="String",
Overwrite=False,
Tier="Standard",
DataType="text",
)


def store_n_key_pairs_on_aws(number_of_keypairs: int):
"""Generate and store a number of keypairs on AWS SSM."""
ssm_client = boto3.client("ssm", region_name="us-west-2")

rootski_wireguard_keypair_data_objs: List[Dict[str, str]] = generate_n_keypairs(
number_of_keys=number_of_keypairs
)

# store all keypairs in SSM
for wireguard_keypaird_data in rootski_wireguard_keypair_data_objs:
ip_address = wireguard_keypaird_data["private_ip_address_on_vpn_network"]
store_key_pair_on_aws(
key_pair_identifier=ip_address, keypair_data=wireguard_keypaird_data, ssm_client=ssm_client
)


if __name__ == "__main__":

NUMBER_OF_KEY_PAIRS_TO_GENERATE = 15
store_n_key_pairs_on_aws(NUMBER_OF_KEY_PAIRS_TO_GENERATE)
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""
Generate a number of wireguard key-pair data using the ``pywgkey`` library.

See the docs: https://pywgkey.readthedocs.io/en/latest/
We use the ``pywgkey`` library to generate wireguard keys for Rootski serverices and contributors,
and put them into a VpnKeyPairData class. Then we will push our key-pair data to aws ssm using
the file store_keys_on_aws.py
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional, Type, TypedDict

from pywgkey import WgKey

#: default owner of a key pair if not already assigned to a rootski contributor
DEFAULT__VPN_IP_ADDRESS_OWNER_NAME = "unassigned"
#: max number of IP addresses on our wireguard VPN CIDR range
MAX_ALLOWED_WIREGUARD_KEY_PAIRS = 253


class NumberOfKeysError(Exception):
"""Exception for generating more keys than currently allowed."""


class VpnKeyPairData(TypedDict):
"""
Storing data for a wireguard Peer configuration.

There is an example of a peer configuration in the following documentation: https://www.wireguard.com/#cryptokey-routing
"""

#: RSA public key for a wireguard peer
public_key: str
#: RSA private key for a wireguard peer
private_key: str
#: contributor or rootski service that reserves this IP address
owner_name: str
#: IP address on the local wireguard VPN network reserved for this peer
private_ip_address_on_vpn_network: str


@dataclass
class KeyPair:
"""Wireguard key pair wrapper."""

public_key: str
private_key: str

@classmethod
def generate(cls: Type[KeyPair]) -> KeyPair:
"""Generate a KeyPair object."""
keypair = WgKey()
return cls(public_key=keypair.pubkey, private_key=keypair.privkey)

def to_dict(
self, ip_address: str, owner_name: Optional[str] = None, note: Optional[str] = None
) -> VpnKeyPairData:
"""Convert KeyPair keys to a VpnKeyPairData dictionary object."""
vpn_keypair_data = VpnKeyPairData(
public_key=self.public_key,
private_key=self.private_key,
owner_name=owner_name or DEFAULT__VPN_IP_ADDRESS_OWNER_NAME,
private_ip_address_on_vpn_network=ip_address,
note=note,
)
return vpn_keypair_data


def generate_n_keypairs(number_of_keys: int, num_reserved_keys: int = 10) -> List[VpnKeyPairData]:
"""
Generate ``num_keys`` wireguard keypairs.

:param num_keys: number of keys to generate
:param num_reserved_keys: reserve certain IP addresses for rootski
services e.g. vpn.rootski.io, mlflow.rootski.io, etc.

:return: a list containing ``num_keys`` :py:class:`KeyPair` object
"""
if number_of_keys > MAX_ALLOWED_WIREGUARD_KEY_PAIRS:
raise NumberOfKeysError("num_keys must be smaller than 254")

# create key pairs for IP addresses reserved for rootski services
reserved_key_pairs = [
KeyPair.generate().to_dict(
ip_address=f"10.0.0.{i+1}", note="This ip-address is reserved and not available for contributors."
)
for i in range(num_reserved_keys)
]

# create key pairs for IP addresses assignable to rootski contributors
unreserved_key_pairs = [
KeyPair.generate().to_dict(ip_address=f"10.0.0.{i+1}") for i in range(num_reserved_keys, number_of_keys)
]
key_pairs = reserved_key_pairs + unreserved_key_pairs

return key_pairs
15 changes: 15 additions & 0 deletions infrastructure/containers/vpn/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Package definition for rootski_vpn."""

import setuptools

setuptools.setup(
name="rootski_vpn",
version="0.0.1",
description="Generate WireGuard Vpn configuration files",
long_description_content_type="text/markdown",
author="rootski-io",
package_dir={"": "rootski_vpn"},
packages=setuptools.find_packages(where="rootski_vpn"),
install_requires=["pywgkey==1.0.0", "boto3", "boto3-stubs[ssm]"],
python_requires=">=3.7",
)