Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions openfl-docker/gramine_app/fx.manifest.template
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,5 @@ sgx.allowed_files = [
"file:{{ workspace_root }}/plan/cols.yaml",
"file:{{ workspace_root }}/plan/data.yaml",
"file:{{ workspace_root }}/plan/plan.yaml",
"file:{{ workspace_root }}/attestation",
]
1 change: 1 addition & 0 deletions openfl-workspace/workspace/plan/defaults/aggregator.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ settings :
last_state_path : save/last.pbuf
persist_checkpoint: True
persistent_db_path: local_state/tensor.db
enable_remote_attestation : False
1 change: 1 addition & 0 deletions openfl-workspace/workspace/plan/defaults/collaborator.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ settings :
opt_treatment : 'CONTINUE_LOCAL'
use_delta_updates : True
db_store_rounds : 1
enable_remote_attestation : False
2 changes: 2 additions & 0 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __init__(
persist_checkpoint=True,
persistent_db_path=None,
secure_aggregation=False,
enable_remote_attestation=False,
):
"""Initializes the Aggregator.

Expand Down Expand Up @@ -146,6 +147,7 @@ def __init__(
self.uuid = aggregator_uuid
self.federation_uuid = federation_uuid
self.connector = connector
self.enable_remote_attestation = enable_remote_attestation

self.quit_job_sent_to = []

Expand Down
239 changes: 239 additions & 0 deletions openfl/cryptography/signer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
import base64
import os
from datetime import datetime, timedelta

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from cryptography.x509.oid import NameOID


class ECDSASigner:
"""ECDSA secp384R1 signer

This class implements ECDSA methods specific to a single instance of
enclave, so, that it can be reused across without passing references
across the breadth of the code. The relevant keys and certificates are
stored in /tmp. /tmp should be appropriately configured in the enclave
manifest, so, that it is wiped off when the enclave exits.

Raises:
Exception: ValueError() for incorrect configuration

Returns:
object: The reference of the only object
"""

__signer_instance = None

@staticmethod
def get_instance(privkey_path="/tmp/client_privkey.pem"):
if ECDSASigner.__signer_instance is None:
ECDSASigner(privkey_path)
return ECDSASigner.__signer_instance

def __init__(self, privkey_path="/tmp/client_privkey.pem", cert_path="/tmp/"):
"""Constructor for creating the ECDSA Certificate Chain

Args:
privkey_path (string, optional): Path to the existing client private key.
Defaults to None.

Raises:
Exception: Generic, in case there's invalid configuration or file data
"""
if ECDSASigner.__signer_instance is not None:
raise Exception("ECDSASigner: Only one instance allowed")
else:
ECDSASigner.__signer_instance = self

self._root_cert_path = os.path.join(cert_path, "openfl-security-ca-cert.pem")
self._client_cert_path = os.path.join(cert_path, "openfl-enclave-cert.pem")

# If the private key already exists, then reuse to create the certificate
# if doesn't exist.
self._client_cert = None
self._root_cert = None

if privkey_path and os.path.exists(privkey_path):
with open(privkey_path, "rb") as client_priv_fh:
client_privkey_pem = client_priv_fh.read()

# Once the private key is found in filesystem, then look for saved
# certificate and the corresponding public key
self._client_privkey = load_pem_private_key(client_privkey_pem, password=None)
if not isinstance(self._client_privkey, ec.EllipticCurvePrivateKey):
raise ValueError(f"Invalid private key format: '{privkey_path}'")

self._client_pubkey = self._client_privkey.public_key()

# Check if certificate is already present in the filesystem, if not then
# serialize is not called yet
if os.path.exists(self._client_cert_path):
with open(self._client_cert_path, "rb") as cert_fh:
self._client_cert = x509.load_pem_x509_certificate(
cert_fh.read(), default_backend()
)

# FIXME: Post upgrading cryptography module, delete this below
# and change above call to load_pem_x509_certificate(s) to load the chain
if not os.path.exists(self._root_cert_path):
raise ValueError(
"Out of tree modification detected, "
"clean all the keys and certs and try again"
)
with open(self._root_cert_path, "rb") as cert_fh:
self._root_cert = x509.load_pem_x509_certificate(
cert_fh.read(), default_backend()
)

else:
self._root_privkey = ec.generate_private_key(ec.SECP384R1(), default_backend())
self._root_pubkey = self._root_privkey.public_key()

self._client_privkey = ec.generate_private_key(ec.SECP384R1(), default_backend())

self._client_pubkey = self._client_privkey.public_key()

def __get_cert(
self,
subject_name,
subject_pubkey,
issuer_name,
issuer_privkey,
ca=False,
mrenclave_data=None,
):
"""Create a certificate with optional MRENCLAVE_OID extension.

Args:
subject_name (string): The subject name in the certificate, must match the URL.
subject_pubkey (string): Subject's public key to be embedded in the certificate.
issuer_name (string): The CA name.
issuer_privkey (string): The CA private key to sign the subject certificate.
ca (bool, optional): To set CA=true property. Defaults to False.
mrenclave_data (string, optional): The mrenclave data to be added as a custom extension.

Returns:
object: Certificate.
"""
# Create the certificate builder
cert_builder = x509.CertificateBuilder()
cert_builder = cert_builder.subject_name(
x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, subject_name)])
)
cert_builder = cert_builder.issuer_name(
x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, issuer_name)])
)

oneday = timedelta(1, 0, 0)
start_validity = datetime.today() - oneday
end_validity = datetime.today() + (100 * oneday)
cert_builder = cert_builder.not_valid_before(start_validity)
cert_builder = cert_builder.not_valid_after(end_validity)

if ca:
cert_builder = cert_builder.add_extension(
x509.BasicConstraints(ca=True, path_length=None), critical=True
)

cert_builder = cert_builder.serial_number(x509.random_serial_number())
cert_builder = cert_builder.public_key(subject_pubkey)

# Add the custom MRENCLAVE_OID extension if mrenclave_data is provided
if mrenclave_data:
MRENCLAVE_OID = x509.ObjectIdentifier("1.3.6.1.4.1.99999.1.1") # Example OID
cert_builder = cert_builder.add_extension(
x509.UnrecognizedExtension(MRENCLAVE_OID, mrenclave_data.encode("utf-8")),
critical=False,
)

# Sign the certificate
return cert_builder.sign(issuer_privkey, algorithm=hashes.SHA384())

def __create_cert_chain(self):
root_cert_bytes = self._root_cert.public_bytes((serialization.Encoding.PEM))
client_cert_bytes = self._client_cert.public_bytes((serialization.Encoding.PEM))

# Concatenate together and return
cert_chain = client_cert_bytes.decode("utf-8") + root_cert_bytes.decode("utf-8")
return cert_chain

def cert(self, common_name, mrenclave_data=None):
"""Returns the self-signed certificate
Args:
common_name (string): to be used as subject and issuer name
Returns:
_type_: _description_
"""

# Create the certificate chaining with local CA
# A. Create a root CA certificate
# B. Create the node certificate

# Return the cached value if it exists
if self._client_cert:
return self.__create_cert_chain()

ca_common_name = f"{common_name}-CA"
self._root_cert = self.__get_cert(
ca_common_name, self._root_pubkey, ca_common_name, self._root_privkey, ca=True
)

self._client_cert = self.__get_cert(
common_name,
self._client_pubkey,
ca_common_name,
self._root_privkey,
False,
mrenclave_data=mrenclave_data,
)

return self.__create_cert_chain()

def get_pubkey(self):
"""returns public key as a PEM string"""

return self._client_pubkey.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
).decode("utf-8")

def sign(self, message):
"""sign message string using private key.

Return Value: base64 encoded string
"""

signature_bytes = self._client_privkey.sign(
message.encode("utf-8"), ec.ECDSA(hashes.SHA384())
)
return base64.b64encode(signature_bytes).decode("utf-8")

def serialize_private_key(self, filename="/tmp/client_privkey.pem", save_root_cert=True):
"""write the private key to a file in PEM format"""

client_privkey_pem = self._client_privkey.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)

with open(filename, "wb") as fh:
fh.write(client_privkey_pem)

# Save the CA certificate if not already exists
if os.path.exists(self._root_cert_path) is False and save_root_cert:
root_cert_bytes = self._root_cert.public_bytes((serialization.Encoding.PEM))
with open(self._root_cert_path, "wb") as fh:
fh.write(root_cert_bytes)

# Save the client certificate if not already exists
if os.path.exists(self._client_cert_path) is False:
cert_chain = self.__create_cert_chain()
with open(self._client_cert_path, "wb") as fh:
fh.write(cert_chain.encode("utf-8"))

return client_privkey_pem
32 changes: 29 additions & 3 deletions openfl/federated/plan/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

"""Plan module."""

import os
from functools import partial
from hashlib import sha384
from importlib import import_module
Expand Down Expand Up @@ -625,6 +626,11 @@ def get_client_args(
certificate = f"cert/client/col_{common_name}.crt"
private_key = f"cert/client/col_{common_name}.key"

# added to test that new root certs generated by aggregator are used by clients
#if os.getenv("ROOT_CERT_PATH", None) is not None:
# cert_path = os.getenv("ROOT_CERT_PATH")
# root_certificate = cert_path + "/cert_chain.crt"

client_args = self.config["network"][SETTINGS]

# patch certificates
Expand All @@ -643,6 +649,7 @@ def get_server(
root_certificate=None,
private_key=None,
certificate=None,
attested_identity=None,
**kwargs,
):
"""Get gRPC or REST server of the aggregator instance.
Expand All @@ -659,7 +666,9 @@ def get_server(
Returns:
Aggregator Server: returns either gRPC or REST server of the aggregator instance.
"""
server_args = self.get_server_args(root_certificate, private_key, certificate, kwargs)
server_args = self.get_server_args(
root_certificate, private_key, certificate, attested_identity, kwargs
)

server_args["aggregator"] = self.get_aggregator()
network_cfg = self.config["network"][SETTINGS]
Expand All @@ -679,14 +688,31 @@ def _get_server(self, protocol, **kwargs):
raise ValueError(f"Unsupported transport_protocol '{protocol}'")
return server

def get_server_args(self, root_certificate, private_key, certificate, kwargs):
def get_server_args(
self, root_certificate, private_key, certificate, attested_identity, kwargs
):
common_name = self.config["network"][SETTINGS]["agg_addr"].lower()

if not root_certificate or not private_key or not certificate:
if not root_certificate or not private_key or not certificate :
root_certificate = "cert/cert_chain.crt"
certificate = f"cert/server/agg_{common_name}.crt"
private_key = f"cert/server/agg_{common_name}.key"

if attested_identity:
cert_chain = os.path.join(
os.path.dirname(attested_identity.get_root_cert_path()),
"cert_chain.crt",
)
logger.info("building root cert: %s", cert_chain)
attested_identity.build_root_cert("cert/client/")
attested_identity.save_cert(f"cert/server/agg_{common_name}_self_signed.crt")
attested_identity.save_private_key(f"/tmp/agg_{common_name}.key")
root_certificate = cert_chain
certificate = f"cert/server/agg_{common_name}_self_signed.crt"
private_key = f"/tmp/agg_{common_name}.key"
else:
logger.info("Remote attestation is not enabled. Using default certificates.")

server_args = self.config["network"][SETTINGS]

# patch certificates
Expand Down
11 changes: 10 additions & 1 deletion openfl/interface/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from openfl.federated import Plan
from openfl.interface.cli_helper import CERT_DIR
from openfl.utilities import click_types
from openfl.utilities.attestation import attestation_utils as attestation_utils
from openfl.utilities.path_check import is_directory_traversal
from openfl.utilities.utils import getfqdn_env

Expand Down Expand Up @@ -91,8 +92,16 @@ def start_(plan, authorized_cols, task_group):
parsed_plan.config["assigner"]["settings"]["selected_task_group"] = task_group
logger.info(f"Setting aggregator to assign: {task_group} task_group")

# check if remote attestation is enabled
attested_identity = None
if parsed_plan.config["aggregator"]["settings"].get("enable_remote_attestation", False):
# check if the aggregator is running in a remote attestation environment
attested_identity = attestation_utils.get_remote_attestation("aggregator")
else:
logger.info("Remote attestation is not enabled.")

logger.info("🧿 Starting the Aggregator Service.")
server = parsed_plan.get_server()
server = parsed_plan.get_server(attested_identity=attested_identity)
server.serve()


Expand Down
4 changes: 3 additions & 1 deletion openfl/interface/collaborator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from openfl.federated import Plan
from openfl.federated.data.sources.data_sources_json_parser import DataSourcesJsonParser
from openfl.interface.cli_helper import CERT_DIR
from openfl.utilities.attestation import attestation_utils as attestation_utils
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not this?

Suggested change
from openfl.utilities.attestation import attestation_utils as attestation_utils
from openfl.utilities.attestation import attestation_utils

from openfl.utilities.path_check import is_directory_traversal
from openfl.utilities.utils import rmtree

Expand Down Expand Up @@ -78,7 +79,8 @@ def start_(plan, collaborator_name, data_config):

# TODO: Need to restructure data loader config file loader
logger.info(f"Data paths: {plan_obj.cols_data_paths}")
echo(f"Data = {plan_obj.cols_data_paths}")
# this check is added to avoid mock objects failing

logger.info("🧿 Starting a Collaborator Service.")

collaborator = plan_obj.get_collaborator(collaborator_name)
Expand Down
Empty file.
Loading
Loading