diff --git a/openfl-docker/gramine_app/fx.manifest.template b/openfl-docker/gramine_app/fx.manifest.template index 55e20adf42..2c70eaa401 100755 --- a/openfl-docker/gramine_app/fx.manifest.template +++ b/openfl-docker/gramine_app/fx.manifest.template @@ -71,4 +71,5 @@ sgx.allowed_files = [ "file:{{ workspace_root }}/plan/cols.yaml", "file:{{ workspace_root }}/plan/data.yaml", "file:{{ workspace_root }}/plan/plan.yaml", + "file:{{ workspace_root }}/attestation", ] diff --git a/openfl-workspace/workspace/plan/defaults/aggregator.yaml b/openfl-workspace/workspace/plan/defaults/aggregator.yaml index e03f51f649..9172bdf276 100644 --- a/openfl-workspace/workspace/plan/defaults/aggregator.yaml +++ b/openfl-workspace/workspace/plan/defaults/aggregator.yaml @@ -6,3 +6,4 @@ settings : last_state_path : save/last.pbuf persist_checkpoint: True persistent_db_path: local_state/tensor.db + enable_remote_attestation : False diff --git a/openfl-workspace/workspace/plan/defaults/collaborator.yaml b/openfl-workspace/workspace/plan/defaults/collaborator.yaml index a119c1cbff..c017aac2ec 100644 --- a/openfl-workspace/workspace/plan/defaults/collaborator.yaml +++ b/openfl-workspace/workspace/plan/defaults/collaborator.yaml @@ -3,3 +3,4 @@ settings : opt_treatment : 'CONTINUE_LOCAL' use_delta_updates : True db_store_rounds : 1 + enable_remote_attestation : False diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index f1c07cf0f4..df4d5b8703 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -91,6 +91,7 @@ def __init__( persist_checkpoint=True, persistent_db_path=None, secure_aggregation=False, + enable_remote_attestation=False, ): """Initializes the Aggregator. @@ -146,6 +147,7 @@ def __init__( self.uuid = aggregator_uuid self.federation_uuid = federation_uuid self.connector = connector + self.enable_remote_attestation = enable_remote_attestation self.quit_job_sent_to = [] diff --git a/openfl/cryptography/signer.py b/openfl/cryptography/signer.py new file mode 100644 index 0000000000..8e0df321fa --- /dev/null +++ b/openfl/cryptography/signer.py @@ -0,0 +1,239 @@ +import base64 +import os +from datetime import datetime, timedelta + +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.serialization import load_pem_private_key +from cryptography.x509.oid import NameOID + + +class ECDSASigner: + """ECDSA secp384R1 signer + + This class implements ECDSA methods specific to a single instance of + enclave, so, that it can be reused across without passing references + across the breadth of the code. The relevant keys and certificates are + stored in /tmp. /tmp should be appropriately configured in the enclave + manifest, so, that it is wiped off when the enclave exits. + + Raises: + Exception: ValueError() for incorrect configuration + + Returns: + object: The reference of the only object + """ + + __signer_instance = None + + @staticmethod + def get_instance(privkey_path="/tmp/client_privkey.pem"): + if ECDSASigner.__signer_instance is None: + ECDSASigner(privkey_path) + return ECDSASigner.__signer_instance + + def __init__(self, privkey_path="/tmp/client_privkey.pem", cert_path="/tmp/"): + """Constructor for creating the ECDSA Certificate Chain + + Args: + privkey_path (string, optional): Path to the existing client private key. + Defaults to None. + + Raises: + Exception: Generic, in case there's invalid configuration or file data + """ + if ECDSASigner.__signer_instance is not None: + raise Exception("ECDSASigner: Only one instance allowed") + else: + ECDSASigner.__signer_instance = self + + self._root_cert_path = os.path.join(cert_path, "openfl-security-ca-cert.pem") + self._client_cert_path = os.path.join(cert_path, "openfl-enclave-cert.pem") + + # If the private key already exists, then reuse to create the certificate + # if doesn't exist. + self._client_cert = None + self._root_cert = None + + if privkey_path and os.path.exists(privkey_path): + with open(privkey_path, "rb") as client_priv_fh: + client_privkey_pem = client_priv_fh.read() + + # Once the private key is found in filesystem, then look for saved + # certificate and the corresponding public key + self._client_privkey = load_pem_private_key(client_privkey_pem, password=None) + if not isinstance(self._client_privkey, ec.EllipticCurvePrivateKey): + raise ValueError(f"Invalid private key format: '{privkey_path}'") + + self._client_pubkey = self._client_privkey.public_key() + + # Check if certificate is already present in the filesystem, if not then + # serialize is not called yet + if os.path.exists(self._client_cert_path): + with open(self._client_cert_path, "rb") as cert_fh: + self._client_cert = x509.load_pem_x509_certificate( + cert_fh.read(), default_backend() + ) + + # FIXME: Post upgrading cryptography module, delete this below + # and change above call to load_pem_x509_certificate(s) to load the chain + if not os.path.exists(self._root_cert_path): + raise ValueError( + "Out of tree modification detected, " + "clean all the keys and certs and try again" + ) + with open(self._root_cert_path, "rb") as cert_fh: + self._root_cert = x509.load_pem_x509_certificate( + cert_fh.read(), default_backend() + ) + + else: + self._root_privkey = ec.generate_private_key(ec.SECP384R1(), default_backend()) + self._root_pubkey = self._root_privkey.public_key() + + self._client_privkey = ec.generate_private_key(ec.SECP384R1(), default_backend()) + + self._client_pubkey = self._client_privkey.public_key() + + def __get_cert( + self, + subject_name, + subject_pubkey, + issuer_name, + issuer_privkey, + ca=False, + mrenclave_data=None, + ): + """Create a certificate with optional MRENCLAVE_OID extension. + + Args: + subject_name (string): The subject name in the certificate, must match the URL. + subject_pubkey (string): Subject's public key to be embedded in the certificate. + issuer_name (string): The CA name. + issuer_privkey (string): The CA private key to sign the subject certificate. + ca (bool, optional): To set CA=true property. Defaults to False. + mrenclave_data (string, optional): The mrenclave data to be added as a custom extension. + + Returns: + object: Certificate. + """ + # Create the certificate builder + cert_builder = x509.CertificateBuilder() + cert_builder = cert_builder.subject_name( + x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, subject_name)]) + ) + cert_builder = cert_builder.issuer_name( + x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, issuer_name)]) + ) + + oneday = timedelta(1, 0, 0) + start_validity = datetime.today() - oneday + end_validity = datetime.today() + (100 * oneday) + cert_builder = cert_builder.not_valid_before(start_validity) + cert_builder = cert_builder.not_valid_after(end_validity) + + if ca: + cert_builder = cert_builder.add_extension( + x509.BasicConstraints(ca=True, path_length=None), critical=True + ) + + cert_builder = cert_builder.serial_number(x509.random_serial_number()) + cert_builder = cert_builder.public_key(subject_pubkey) + + # Add the custom MRENCLAVE_OID extension if mrenclave_data is provided + if mrenclave_data: + MRENCLAVE_OID = x509.ObjectIdentifier("1.3.6.1.4.1.99999.1.1") # Example OID + cert_builder = cert_builder.add_extension( + x509.UnrecognizedExtension(MRENCLAVE_OID, mrenclave_data.encode("utf-8")), + critical=False, + ) + + # Sign the certificate + return cert_builder.sign(issuer_privkey, algorithm=hashes.SHA384()) + + def __create_cert_chain(self): + root_cert_bytes = self._root_cert.public_bytes((serialization.Encoding.PEM)) + client_cert_bytes = self._client_cert.public_bytes((serialization.Encoding.PEM)) + + # Concatenate together and return + cert_chain = client_cert_bytes.decode("utf-8") + root_cert_bytes.decode("utf-8") + return cert_chain + + def cert(self, common_name, mrenclave_data=None): + """Returns the self-signed certificate + Args: + common_name (string): to be used as subject and issuer name + Returns: + _type_: _description_ + """ + + # Create the certificate chaining with local CA + # A. Create a root CA certificate + # B. Create the node certificate + + # Return the cached value if it exists + if self._client_cert: + return self.__create_cert_chain() + + ca_common_name = f"{common_name}-CA" + self._root_cert = self.__get_cert( + ca_common_name, self._root_pubkey, ca_common_name, self._root_privkey, ca=True + ) + + self._client_cert = self.__get_cert( + common_name, + self._client_pubkey, + ca_common_name, + self._root_privkey, + False, + mrenclave_data=mrenclave_data, + ) + + return self.__create_cert_chain() + + def get_pubkey(self): + """returns public key as a PEM string""" + + return self._client_pubkey.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode("utf-8") + + def sign(self, message): + """sign message string using private key. + + Return Value: base64 encoded string + """ + + signature_bytes = self._client_privkey.sign( + message.encode("utf-8"), ec.ECDSA(hashes.SHA384()) + ) + return base64.b64encode(signature_bytes).decode("utf-8") + + def serialize_private_key(self, filename="/tmp/client_privkey.pem", save_root_cert=True): + """write the private key to a file in PEM format""" + + client_privkey_pem = self._client_privkey.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + with open(filename, "wb") as fh: + fh.write(client_privkey_pem) + + # Save the CA certificate if not already exists + if os.path.exists(self._root_cert_path) is False and save_root_cert: + root_cert_bytes = self._root_cert.public_bytes((serialization.Encoding.PEM)) + with open(self._root_cert_path, "wb") as fh: + fh.write(root_cert_bytes) + + # Save the client certificate if not already exists + if os.path.exists(self._client_cert_path) is False: + cert_chain = self.__create_cert_chain() + with open(self._client_cert_path, "wb") as fh: + fh.write(cert_chain.encode("utf-8")) + + return client_privkey_pem diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py index 4d0459cbbc..00225e77a1 100644 --- a/openfl/federated/plan/plan.py +++ b/openfl/federated/plan/plan.py @@ -4,6 +4,7 @@ """Plan module.""" +import os from functools import partial from hashlib import sha384 from importlib import import_module @@ -625,6 +626,11 @@ def get_client_args( certificate = f"cert/client/col_{common_name}.crt" private_key = f"cert/client/col_{common_name}.key" + # added to test that new root certs generated by aggregator are used by clients + #if os.getenv("ROOT_CERT_PATH", None) is not None: + # cert_path = os.getenv("ROOT_CERT_PATH") + # root_certificate = cert_path + "/cert_chain.crt" + client_args = self.config["network"][SETTINGS] # patch certificates @@ -643,6 +649,7 @@ def get_server( root_certificate=None, private_key=None, certificate=None, + attested_identity=None, **kwargs, ): """Get gRPC or REST server of the aggregator instance. @@ -659,7 +666,9 @@ def get_server( Returns: Aggregator Server: returns either gRPC or REST server of the aggregator instance. """ - server_args = self.get_server_args(root_certificate, private_key, certificate, kwargs) + server_args = self.get_server_args( + root_certificate, private_key, certificate, attested_identity, kwargs + ) server_args["aggregator"] = self.get_aggregator() network_cfg = self.config["network"][SETTINGS] @@ -679,14 +688,31 @@ def _get_server(self, protocol, **kwargs): raise ValueError(f"Unsupported transport_protocol '{protocol}'") return server - def get_server_args(self, root_certificate, private_key, certificate, kwargs): + def get_server_args( + self, root_certificate, private_key, certificate, attested_identity, kwargs + ): common_name = self.config["network"][SETTINGS]["agg_addr"].lower() - if not root_certificate or not private_key or not certificate: + if not root_certificate or not private_key or not certificate : root_certificate = "cert/cert_chain.crt" certificate = f"cert/server/agg_{common_name}.crt" private_key = f"cert/server/agg_{common_name}.key" + if attested_identity: + cert_chain = os.path.join( + os.path.dirname(attested_identity.get_root_cert_path()), + "cert_chain.crt", + ) + logger.info("building root cert: %s", cert_chain) + attested_identity.build_root_cert("cert/client/") + attested_identity.save_cert(f"cert/server/agg_{common_name}_self_signed.crt") + attested_identity.save_private_key(f"/tmp/agg_{common_name}.key") + root_certificate = cert_chain + certificate = f"cert/server/agg_{common_name}_self_signed.crt" + private_key = f"/tmp/agg_{common_name}.key" + else: + logger.info("Remote attestation is not enabled. Using default certificates.") + server_args = self.config["network"][SETTINGS] # patch certificates diff --git a/openfl/interface/aggregator.py b/openfl/interface/aggregator.py index 16dc48e9bf..6d4a4ed36d 100644 --- a/openfl/interface/aggregator.py +++ b/openfl/interface/aggregator.py @@ -32,6 +32,7 @@ from openfl.federated import Plan from openfl.interface.cli_helper import CERT_DIR from openfl.utilities import click_types +from openfl.utilities.attestation import attestation_utils as attestation_utils from openfl.utilities.path_check import is_directory_traversal from openfl.utilities.utils import getfqdn_env @@ -91,8 +92,16 @@ def start_(plan, authorized_cols, task_group): parsed_plan.config["assigner"]["settings"]["selected_task_group"] = task_group logger.info(f"Setting aggregator to assign: {task_group} task_group") + # check if remote attestation is enabled + attested_identity = None + if parsed_plan.config["aggregator"]["settings"].get("enable_remote_attestation", False): + # check if the aggregator is running in a remote attestation environment + attested_identity = attestation_utils.get_remote_attestation("aggregator") + else: + logger.info("Remote attestation is not enabled.") + logger.info("🧿 Starting the Aggregator Service.") - server = parsed_plan.get_server() + server = parsed_plan.get_server(attested_identity=attested_identity) server.serve() diff --git a/openfl/interface/collaborator.py b/openfl/interface/collaborator.py index ba0a84977c..8d70a8fff6 100644 --- a/openfl/interface/collaborator.py +++ b/openfl/interface/collaborator.py @@ -24,6 +24,7 @@ from openfl.federated import Plan from openfl.federated.data.sources.data_sources_json_parser import DataSourcesJsonParser from openfl.interface.cli_helper import CERT_DIR +from openfl.utilities.attestation import attestation_utils as attestation_utils from openfl.utilities.path_check import is_directory_traversal from openfl.utilities.utils import rmtree @@ -78,7 +79,8 @@ def start_(plan, collaborator_name, data_config): # TODO: Need to restructure data loader config file loader logger.info(f"Data paths: {plan_obj.cols_data_paths}") - echo(f"Data = {plan_obj.cols_data_paths}") + # this check is added to avoid mock objects failing + logger.info("🧿 Starting a Collaborator Service.") collaborator = plan_obj.get_collaborator(collaborator_name) diff --git a/openfl/utilities/attestation/__init__.py b/openfl/utilities/attestation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openfl/utilities/attestation/attestation_utils.py b/openfl/utilities/attestation/attestation_utils.py new file mode 100644 index 0000000000..651a6f7134 --- /dev/null +++ b/openfl/utilities/attestation/attestation_utils.py @@ -0,0 +1,481 @@ +# Copyright 2020-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import base64 +import hashlib +import http +import json +import logging +import os + +import requests + +from openfl.cryptography.signer import ECDSASigner + +logger = logging.getLogger(__name__) + + +class AttestedIdentity: + """Class to represent an attested identity for a participant enclave.""" + + def __init__(self, mrenclave, name, token, cert, private_key, public_key, root_cert_path): + """Initializes the AttestedIdentity with the provided parameters. + + Args: + mrenclave (str): MRENCLAVE value of the enclave. + name (str): Name of the enclave. + token (str): AVS report token. + cert (bytes): Certificate for mTLS communication. + private_key (bytes): Private key for mTLS communication. + """ + self.mrenclave = mrenclave + self.name = name + self.token = token + self.cert = cert + self.private_key = private_key + self.public_key = public_key + self.root_cert_path = root_cert_path + + # define getter/seetter methods for the attributes if needed + + def get_mrenclave(self): + """Returns the MRENCLAVE value of the enclave.""" + return self.mrenclave + + def get_name(self): + """Returns the name of the enclave.""" + return self.name + + def get_token(self): + """Returns the AVS report token.""" + return self.token + + def get_cert(self): + """Returns the certificate for mTLS communication.""" + return self.cert + + def get_private_key(self): + """Returns the private key for mTLS communication.""" + return self.private_key + + def set_mrenclave(self, mrenclave): + """Sets the MRENCLAVE value of the enclave.""" + self.mrenclave = mrenclave + + def get_passport(self): + """Returns a dictionary representation of the attested identity. + + Returns: + dict: A dictionary containing the MRENCLAVE value, participant name, + AVS report token, certificate, and private key. + """ + return { + "mrenclave": self.mrenclave, + "name": self.name, + "token": self.token, + "cert": self.cert, + } + + def save_passport(self, path): + """Saves the attested identity passport to a file. + + Args: + path (str): Path to save the passport file. + """ + if not os.path.exists(os.path.dirname(path)): + # give an error if the directory does not exist + raise FileNotFoundError(f"Directory {os.path.dirname(path)} does not exist") + + passport = self.get_passport() + with open(path, "w") as f: + json.dump(passport, f, indent=4) + + def save_token(self, path): + """Saves the AVS report token to a file. + + Args: + path (str): Path to save the token file. + """ + if not os.path.exists(os.path.dirname(path)): + # give an error if the directory does not exist + raise FileNotFoundError(f"Directory {os.path.dirname(path)} does not exist") + + with open(path, "w") as f: + f.write(self.token) + + def save_pubkey(self, path): + """Saves the public key to a file. + + Args: + path (str): Path to save the public key file. + """ + if not os.path.exists(os.path.dirname(path)): + # give an error if the directory does not exist + raise FileNotFoundError(f"Directory {os.path.dirname(path)} does not exist") + + with open(path, "wb") as f: + f.write(self.public_key) + + def save_cert(self, path): + """Saves the certificate to a file. + + Args: + path (str): Path to save the certificate file. + """ + if not os.path.exists(os.path.dirname(path)): + # give an error if the directory does not exist + raise FileNotFoundError(f"Directory {os.path.dirname(path)} does not exist") + + with open(path, "wb") as f: + f.write(self.cert) + + def save_private_key(self, path): + """Saves the private key to a file. + + Args: + path (str): Path to save the private key file. + """ + if not os.path.exists(os.path.dirname(path)): + # give an error if the directory does not exist + raise FileNotFoundError(f"Directory {os.path.dirname(path)} does not exist") + + with open(path, "wb") as f: + f.write(self.private_key) + + def build_root_cert(self, client_certs_path): + """Builds the root certificate for the attested identity. + + Returns: + str: Path to the root certificate. + """ + if not os.path.exists(client_certs_path): + raise FileNotFoundError(f"Client certs path {client_certs_path} does not exist") + + agg_cert = self.get_cert() + # find all .crt files in client cert path and add them to the root cert + client_certs = [] + for root, dirs, files in os.walk(client_certs_path): + for file in files: + if file.endswith(".crt"): + logger.info(f"Found client cert: {file} in {root}") + client_certs.append(os.path.join(root, file)) + + # create a root cert with the agg cert and client certs + root_cert_chain = os.path.join( + os.path.dirname(self.get_root_cert_path()), + "cert_chain.crt", + ) + with open(root_cert_chain, "wb") as f: + f.write(agg_cert) + for client_cert in client_certs: + with open(client_cert, "rb") as cf: + f.write(cf.read()) + logger.info(f"Root cert created at {root_cert_chain}") + # print the root cert + with open(root_cert_chain, "rb") as f: + root_cert = f.read() + logger.info(f"Root cert: {root_cert}") + + return self.root_cert_path + + def get_root_cert_path(self): + """Returns the path to the attestation report. + + Returns: + str: Path to the attestation report. + """ + return self.root_cert_path + + +class AttestationManager: + """Class to manage attestation for participant enclaves. + This class handles the generation of SGX quotes, fetching MRENCLAVE values, + and obtaining attestation reports from AVS (Attestation Verification Service). + """ + + def __init__( + self, + participant_name, + attestation_report_path, + ita_api_key, + avs_url, + root_cert_path, + privkey_path=None, + ): + """Initializes the AttestationManager with the provided parameters. + + Args: + participant_name (str): Name of the enclave. + attestation_report_path (str): Path to the store attestation reports. + ita_api_key (str): API key for ITA. + avs_url (str): URL for AVS. + """ + self.participant_name = participant_name + self.attestation_report_path = attestation_report_path + self.ita_api_key = ita_api_key + self.avs_url = avs_url + self.root_cert_path = root_cert_path + logger.info(f"root cert path: {self.root_cert_path}") + self.privkey_path = privkey_path + if self.privkey_path is None: + self.ecdsa_p_384_signer = ECDSASigner.get_instance() + else: + self.ecdsa_p_384_signer = ECDSASigner.get_instance(self.privkey_path) + + def get_attested_identity(self, cert_host="localhost"): + """Generates an attested identity, where a public/private key pair + is bound to the workload measurement (MRENCLAVE) via + a remote attestation report. + Args: + None + Raises: + FileNotFoundError: If the attestation report path does not exist. + ValueError: If the ITA API key or AVS URL is not set. + Returns: + dict: A dictionary containing the MRENCLAVE value, participant name, + and AVS report token. + """ + + if not os.path.exists(self.attestation_report_path): + raise FileNotFoundError( + f"attestation report path {self.attestation_report_path} does not exist" + ) + + if self.ita_api_key is None: + raise ValueError("ITA API key is required for remote attestation") + if self.avs_url is None: + raise ValueError("AVS URL is required for remote attestation") + + mrenclave = self.fetch_mrenclave_from_quote(None) + logger.info(f"Enclave MRENCLAVE: {mrenclave}") + + # Generate ECDSA-P 384 key pair and certificate + # This is the enclave specific ECDSA-P 384 keypair for setting up mTLS + # with collaborator enclave + + os.path.join(self.attestation_report_path, f"{self.participant_name}_privkey.pem") + cert = self.ecdsa_p_384_signer.cert(cert_host, mrenclave).encode("utf-8") + key = self.ecdsa_p_384_signer.serialize_private_key() + self_signed_cert_path = os.path.join( + self.attestation_report_path, f"{self.participant_name}_pubkey.pem" + ) + + self.pubkey = self.ecdsa_p_384_signer.get_pubkey() + avs_report = self.get_avs_report_ita( + self.participant_name, + self.attestation_report_path, + cert, + self.avs_url, + self.ita_api_key, + ) + logger.info(f"AVS report: {avs_report}") + attested_identity = AttestedIdentity( + mrenclave=mrenclave, + name=self.participant_name, + token=avs_report["token"], + cert=cert, + private_key=key, + public_key=self.pubkey, + root_cert_path=self.root_cert_path, + ) + # save public key to file + attested_identity.save_cert(self_signed_cert_path) + + return attested_identity + + def get_ecdsa_p_384_signer(self): + """Returns the ECDSA-P 384 signer instance. + + Returns: + ECDSASigner: The ECDSA-P 384 signer instance. + """ + return self.ecdsa_p_384_signer + + def gen_sgx_quote( + self, + user_data_bytes=None, + quote_dump_path="/tmp/quote.json", + orig_data=None, + challenge="0000000000000", + ): + """Generates the SGX quote and saves it to the specified path. + Args: + user_data_bytes (bytes): User data to be included in the quote. + quote_dump_path (str): Path to save the generated quote. + orig_data (bytes): Original data to be included in the quote. + challenge (str): Challenge string to be included in the quote. + Returns: + dict: A dictionary containing the generated quote. + Raises: + ValueError: If the user data is not a bytes object or exceeds 64 bytes. + """ + + # Set user data + if user_data_bytes is not None: + # Check that user data is a bytes like object + if isinstance(user_data_bytes, bytes) is not True: + raise ValueError("User data must be a bytes object") + + # Check that user data is at most 64 bytes + if len(list(user_data_bytes)) > 64: + raise ValueError("User data can be at most 64 bytes") + + # Set user report data + with open("/dev/attestation/user_report_data", "wb") as fh: + fh.write(user_data_bytes) + + # Generate attestation quote + quote = None + with open("/dev/attestation/quote", "rb") as fh: + quote = fh.read(8192) + + # Create quote as expected by AVS for verification + # a. base64 encode the quote + # b. Set 'userData' to empty, we do not want AVS to verify user data + quote_avs = {} + quote_avs["quote"] = base64.b64encode(quote).decode("utf-8") + + if orig_data is not None: + quote_avs["runtime_data"] = base64.b64encode(orig_data).decode("utf-8") + else: + quote_avs["runtime_data"] = "" + + # Save the quote in the specified location + with open(quote_dump_path, "w") as fh: + json.dump(quote_avs, fh) + + return quote_avs + + def fetch_mrenclave_from_quote(self, quote=None): + """Fetches the MRENCLAVE value from the SGX quote. + + Args: + quote (str): The SGX quote in JSON format. + + Returns: + str: The MRENCLAVE value extracted from the quote. + """ + + if quote is None: + # Generate attestation quote + with open("/dev/attestation/quote", "rb") as fh: + quote = fh.read(8192) + + # Create quote as expected by AVS for verification + # a. base64 encode the quote + # b. Set 'userData' to empty, we do not want AVS to verify user data + quote_avs = {} + quote_avs["quote"] = base64.b64encode(quote).decode("utf-8") + + # Decode the base64 encoded quote + decoded_quote = base64.b64decode(quote_avs["quote"]) + + # Extract the MRENCLAVE value from the decoded quote + mrenclave_hex = decoded_quote[112:144].hex() + logger.info(f"Extracted MRENCLAVE: {mrenclave_hex}") + + return mrenclave_hex + + def get_avs_report_ita(self, name, attestation_report_path, cert, avs_url, ita_api_key): + """Fetches the attestation report from AVS using the ITA API key. + This function generates the SGX quote, sends it to AVS for attestation, + and saves the attestation report in the specified path. + Args: + name (str): Name of the enclave. + attestation_report_path (str): Path to save the attestation report. + cert (bytes): Certificate for mTLS with the collaborator enclave. + avs_url (str): URL for AVS. + ita_api_key (str): API key for ITA. + Returns: + dict: A dictionary containing the attestation report from AVS. + Raises: + Exception: If the connection to AVS fails or if the response is not OK. + """ + # Set the paths for getting the enclave quote and ITA AVS response + quote_dir = os.path.normpath("/tmp") + quote_path = os.path.join(quote_dir, f"{name}_quote.json") + avs_report_path = os.path.join(attestation_report_path, f"{name}_avs_report.json") + + # AVS doesn't work with SHA384, so, moving to SHA256 in the meantime + cert_sha256_digest = hashlib.sha256(cert).digest() + self.gen_sgx_quote(cert_sha256_digest, quote_path, cert) + + # ITA attestation + avs_attest_endpoint = f"{avs_url}/appraisal/v1/attest" + headers = { + "Accept": "application/json", + "Content-Type": "application/json", + "x-api-key": ita_api_key, + } + + # Read the quote and send it to AVS for getting the report + with open(quote_path) as fh: + quote_avs = fh.read() + + # AVS has a self-signed certificate, so, disabling verification + res = requests.post(avs_attest_endpoint, headers=headers, data=quote_avs) + if res.status_code != http.HTTPStatus.OK: + raise Exception(f"Failed to connect with {avs_url}, err: {res.status_code}") + + avs_report = res.content + with open(avs_report_path, "wb") as fh: + fh.write(avs_report) + + # Register enclave with governor + avs_report = avs_report.decode("utf-8") + avs_report = json.loads(avs_report) + + return avs_report + + +def fetch_attestation_env_vars(): + """Fetches attestation environment variables from the system. + This function retrieves the ITA API key, AVS URL, and attestation report path + from the environment variables. If the attestation report path is not set, + it defaults to None. + Args: + None + Raises: + None + + Returns: + dict: A dictionary containing the attestation environment variables. + """ + env_vars = { + "ITA_API_KEY": os.getenv("ITA_API_KEY"), + "AVS_URL": os.getenv("AVS_URL"), + "ATTESTATION_REPORT_PATH": os.getenv("ATTESTATION_REPORT_PATH", None), + "ROOT_CERT_PATH": os.getenv("ROOT_CERT_PATH", None), + } + return env_vars + + +def get_remote_attestation(participant_name): + """Starts the remote attestation process for the participant enclave. + This function initializes the AttestationManager and generates an attested + identity for the participant enclave. + Args: + plan_config (str): Path to the plan configuration file. + participant_name (str): Name of the enclave. + + Returns: + AttestedIdentity: An instance of the AttestedIdentity class + containing the attested identity. + """ + # Fetch remote attestation environment variables + attestation_env = fetch_attestation_env_vars() + attested_identity = None + if attestation_env is not None: + attestation_manager = AttestationManager( + participant_name, + attestation_env["ATTESTATION_REPORT_PATH"], + attestation_env["ITA_API_KEY"], + attestation_env["AVS_URL"], + attestation_env["ROOT_CERT_PATH"], + ) + # Generate and store the attestation report + attested_identity = attestation_manager.get_attested_identity() + logger.info("Remote attestation report stored successfully.") + else: + logger.error("Remote attestation environment variables not set.") + return attested_identity diff --git a/openfl/utilities/secagg/__init__.py b/openfl/utilities/secagg/__init__.py index 6b75185a78..4a41a1df4f 100644 --- a/openfl/utilities/secagg/__init__.py +++ b/openfl/utilities/secagg/__init__.py @@ -1,13 +1,2 @@ # Copyright 2020-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - - -from openfl.utilities.secagg.crypto import ( - calculate_mask, - calculate_shared_mask, - create_ciphertext, - decipher_ciphertext, - pseudo_random_generator, -) -from openfl.utilities.secagg.key import generate_agreed_key, generate_key_pair -from openfl.utilities.secagg.shamir import create_secret_shares, reconstruct_secret diff --git a/tests/openfl/interface/test_aggregator_api.py b/tests/openfl/interface/test_aggregator_api.py index 3049a40c91..a390d4df05 100644 --- a/tests/openfl/interface/test_aggregator_api.py +++ b/tests/openfl/interface/test_aggregator_api.py @@ -26,6 +26,12 @@ def test_aggregator_start(mock_parse): 'assigner': { 'settings': { 'selected_task_group': 'learning' + + } + }, + 'aggregator': { + 'settings': { + 'enable_remote_attestation': False } } } diff --git a/tests/openfl/interface/test_collaborator.py b/tests/openfl/interface/test_collaborator.py index 654a58856c..fd55537614 100644 --- a/tests/openfl/interface/test_collaborator.py +++ b/tests/openfl/interface/test_collaborator.py @@ -16,7 +16,24 @@ def test_collaborator_start(mock_parse): plan_config = plan_path.joinpath('plan.yaml') data_config = plan_path.joinpath('data.yaml') - mock_parse.return_value = mock.Mock() + + + mock_plan = mock.MagicMock() + mock_plan.__getitem__.side_effect = {'task_group': 'learning'}.get + mock_plan.get = {'task_group': 'learning'}.get + mock_plan.config = { + 'assigner': { + 'settings': { + 'selected_task_group': 'learning' + } + }, + 'collaborator': { + 'settings': { + 'enable_remote_attestation': False + } + } + } + mock_parse.return_value = mock_plan ret = start_(['-p', plan_config, '-d', data_config,