diff --git a/fbpcp/gateway/kms.py b/fbpcp/gateway/kms.py new file mode 100644 index 00000000..2f6efe88 --- /dev/null +++ b/fbpcp/gateway/kms.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from base64 import b64decode, b64encode +from typing import Any, Dict, List, Optional + +import boto3 +from botocore.client import BaseClient +from fbpcp.decorator.error_handler import error_handler +from fbpcp.gateway.aws import AWSGateway + + +class KMSGateway(AWSGateway): + def __init__( + self, + region: str, + access_key_id: Optional[str] = None, + access_key_data: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(region, access_key_id, access_key_data, config) + self.client: BaseClient = boto3.client( + "kms", region_name=self.region, **self.config + ) + + @error_handler + def sign( + self, + key_id: str, + message: str, + message_type: str, + grant_tokens: List[str], + signing_algorithm: str, + ) -> str: + response = self.client.sign( + KeyId=key_id, + Message=message.encode(), + MessageType=message_type, + GrantTokens=grant_tokens, + SigningAlgorithm=signing_algorithm, + ) + signature = b64encode(response["Signature"]).decode() + return signature + + @error_handler + def verify( + self, + key_id: str, + message: str, + message_type: str, + signature: str, + signing_algorithm: str, + grant_tokens: List[str], + ) -> bool: + b64_signature = b64decode(signature.encode()) + response = self.client.verify( + KeyId=key_id, + Message=message.encode(), + MessageType=message_type, + Signature=b64_signature, + SigningAlgorithm=signing_algorithm, + GrantTokens=grant_tokens, + ) + return response["SignatureValid"] diff --git a/fbpcp/service/key_management.py b/fbpcp/service/key_management.py new file mode 100644 index 00000000..e8b50ada --- /dev/null +++ b/fbpcp/service/key_management.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import abc + + +class KeyManagementService(abc.ABC): + @abc.abstractmethod + def sign(self, message: str) -> str: + pass + + @abc.abstractmethod + def verify(self, message: str, signature: str) -> bool: + pass diff --git a/fbpcp/service/key_management_aws.py b/fbpcp/service/key_management_aws.py new file mode 100644 index 00000000..75a58421 --- /dev/null +++ b/fbpcp/service/key_management_aws.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, Dict, List, Optional + +from fbpcp.gateway.kms import KMSGateway +from fbpcp.service.key_management import KeyManagementService + + +class AWSKeyManagementService(KeyManagementService): + key_id: str + signing_algorithm: str + grant_tokens: List[str] + + def __init__( + self, + region: str, + key_id: str, + signing_algorithm: Optional[str] = None, + grant_tokens: Optional[List[str]] = None, + access_key_id: Optional[str] = None, + access_key_data: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Args: + grant_tokens: Advertiser side specific, allows anyone with a grant token to have permisions to certain functions on KMS (Admin Controlled) + """ + self.kms_gateway = KMSGateway(region, access_key_id, access_key_data, config) + self.key_id = key_id + self.signing_algorithm = signing_algorithm if signing_algorithm else "" + self.grant_tokens = grant_tokens if grant_tokens else [] + + def sign(self, message: str, message_type: str = "RAW") -> str: + if not self.signing_algorithm: + raise ValueError("No Signing Algorithm Set") + signature = self.kms_gateway.sign( + key_id=self.key_id, + message=message, + message_type=message_type, + grant_tokens=self.grant_tokens, + signing_algorithm=self.signing_algorithm, + ) + return signature + + def verify(self, message: str, signature: str, message_type: str = "RAW") -> bool: + valid = self.kms_gateway.verify( + key_id=self.key_id, + message=message, + message_type=message_type, + signature=signature, + signing_algorithm=self.signing_algorithm, + grant_tokens=self.grant_tokens, + ) + return valid diff --git a/onedocker/script/cli/onedocker_cli.py b/onedocker/script/cli/onedocker_cli.py index 957b975f..4d2ef6b7 100644 --- a/onedocker/script/cli/onedocker_cli.py +++ b/onedocker/script/cli/onedocker_cli.py @@ -41,14 +41,14 @@ from onedocker.service.attestation import AttestationService logger = None -onedocker_svc = None -container_svc = None -onedocker_package_repo = None -onedocker_checksum_repo = None + attestation_svc = None +container_svc = None +onedocker_svc = None log_svc = None -task_definition = None -repository_path = None + +onedocker_checksum_repo = None +onedocker_package_repo = None DEFAULT_BINARY_VERSION = "latest" DEFAULT_TIMEOUT = 18000 @@ -171,12 +171,8 @@ def _build_log_service(config: Dict[str, Any]) -> LogService: return log_class(**config["constructor"]) -def _build_exe_s3_path(repository_path: str, package_name: str, version: str) -> str: - return f"{repository_path}{package_name}/{version}/{package_name.split('/')[-1]}" - - def main() -> None: - global container_svc, onedocker_svc, onedocker_package_repo, onedocker_checksum_repo, log_svc, logger, task_definition, repository_path, attestation_svc + global container_svc, onedocker_svc, onedocker_package_repo, onedocker_checksum_repo, log_svc, logger, attestation_svc s = schema.Schema( { "upload": bool, @@ -208,30 +204,46 @@ def main() -> None: version = ( arguments["--version"] if arguments["--version"] else DEFAULT_BINARY_VERSION ) - enable_attestation = arguments["--enable_attestation"] config = yaml.load(Path(arguments["--config"])).get("onedocker-cli") - task_definition = config["setting"]["task_definition"] - repository_path = config["setting"]["repository_path"] - checksum_repository_path = config["setting"].get("checksum_repository_path", "") - - attestation_svc = AttestationService() - storage_svc = _build_storage_service(config["dependency"]["StorageService"]) - container_svc = _build_container_service(config["dependency"]["ContainerService"]) - onedocker_svc = OneDockerService(container_svc, task_definition) - onedocker_package_repo = OneDockerPackageRepository(storage_svc, repository_path) - onedocker_checksum_repo = OneDockerChecksumRepository( - storage_svc, checksum_repository_path - ) - log_svc = _build_log_service(config["dependency"]["LogService"]) - status = "enabled" if enable_attestation else "disabled" - logger.info(f"Package tracking for package {package_name}: {version} is {status}") + if arguments["upload"] or arguments["show"]: + repository_path = config["setting"]["repository_path"] + checksum_repository_path = config["setting"].get("checksum_repository_path", "") + + storage_svc = _build_storage_service(config["dependency"]["StorageService"]) + + onedocker_package_repo = OneDockerPackageRepository( + storage_svc, repository_path + ) + if checksum_repository_path: + onedocker_checksum_repo = OneDockerChecksumRepository( + storage_svc, checksum_repository_path + ) + + if arguments["test"] or arguments["stop"]: + task_definition = config["setting"]["task_definition"] + container_svc = _build_container_service( + config["dependency"]["ContainerService"] + ) + onedocker_svc = OneDockerService(container_svc, task_definition) if arguments["upload"]: + enable_attestation = arguments["--enable_attestation"] + + status = "enabled" if enable_attestation else "disabled" + logger.info( + f"Package tracking for package {package_name}: {version} is {status}" + ) + + attestation_svc = AttestationService() + _upload(package_dir, package_name, version, enable_attestation) elif arguments["test"]: - timeout = arguments["--timeout"] if arguments["--timeout"] else DEFAULT_TIMEOUT + timeout = arguments.get("--timeout", DEFAULT_TIMEOUT) + + log_svc = _build_log_service(config["dependency"]["LogService"]) + _test(package_name, version, arguments["--cmd_args"], timeout) elif arguments["show"]: _show(package_name, arguments["--version"]) diff --git a/tests/gateway/test_kms.py b/tests/gateway/test_kms.py new file mode 100644 index 00000000..c7e299c4 --- /dev/null +++ b/tests/gateway/test_kms.py @@ -0,0 +1,66 @@ +#!/usr/bin/env fbpython +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import base64 +import base64 as b64 +import unittest +from unittest.mock import MagicMock, patch + +from fbpcp.gateway.kms import KMSGateway + + +class TestKMSGateway(unittest.TestCase): + REGION = "us-west-2" + TEST_ACCESS_KEY_ID = "test-access-key-id" + TEST_ACCESS_KEY_DATA = "test-access-key-data" + + @patch("boto3.client") + def setUp(self, BotoClient) -> None: + self.kms = KMSGateway( + region=self.REGION, + access_key_id=self.TEST_ACCESS_KEY_ID, + access_key_data=self.TEST_ACCESS_KEY_DATA, + ) + self.kms.client = BotoClient() + + def test_sign(self) -> None: + # Arrange + sign_args = { + "key_id": "test_key_id", + "message": "test_message", + "message_type": "test_message_type", + "grant_tokens": [], + "signing_algorithm": "", + } + signed_message = "test_signed_message" + self.kms.client.sign = MagicMock( + return_value={"Signature": signed_message.encode()} + ) + + # Act + b64_signature = self.kms.sign(**sign_args) + signature = b64.b64decode(b64_signature.encode()).decode() + + # Assert + self.assertEqual(signature, signed_message) + + def test_verify(self) -> None: + # Arrange + verify_args = { + "key_id": "test_key_id", + "message": "test_message", + "message_type": "test_message_type", + "signature": "dGVzdF9tZXNzYWdl", + "grant_tokens": [], + "signing_algorithm": "", + } + self.kms.client.verify = MagicMock(return_value={"SignatureValid": True}) + + # Act + verification = self.kms.verify(**verify_args) + + # Assert + self.assertTrue(verification) diff --git a/tests/service/test_key_managment_aws.py b/tests/service/test_key_managment_aws.py new file mode 100644 index 00000000..55e39948 --- /dev/null +++ b/tests/service/test_key_managment_aws.py @@ -0,0 +1,63 @@ +#!/usr/bin/env fbpython +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from unittest.mock import MagicMock, patch + +from fbpcp.service.key_management_aws import AWSKeyManagementService + + +class TestAWSKeyManagementService(unittest.TestCase): + REGION = "us-west-2" + + TEST_KEY_ID = "test-key-id" + TEST_SIGNING_ALGORITHM = "test-signing-algorithm" + + TEST_ACCESS_KEY_ID = "test-access-key-id" + TEST_ACCESS_KEY_DATA = "test-access-key-data" + + @patch("boto3.client") + def setUp(self, BotoClient) -> None: + self.kms_aws = AWSKeyManagementService( + region=self.REGION, + key_id=self.TEST_KEY_ID, + signing_algorithm=self.TEST_SIGNING_ALGORITHM, + access_key_id=self.TEST_ACCESS_KEY_ID, + access_key_data=self.TEST_ACCESS_KEY_DATA, + ) + self.kms_aws.kms_gateway.client = BotoClient() + + def test_sign(self) -> None: + # Arrange + sign_args = { + "message": "test_message", + "message_type": "test_message_type", + } + test_signature = "test_signature" + + self.kms_aws.kms_gateway.sign = MagicMock(return_value=test_signature) + + # Act + signature = self.kms_aws.sign(**sign_args) + + # Assert + self.assertEqual(signature, test_signature) + + def test_verify(self) -> None: + # Arrange + verify_args = { + "signature": "test_signature", + "message": "test_message", + "message_type": "test_message_type", + } + + self.kms_aws.kms_gateway.verify = MagicMock(return_value=True) + + # Act + status = self.kms_aws.verify(**verify_args) + + # Assert + self.assertTrue(status)