diff --git a/tracecat/config.py b/tracecat/config.py index 4ff314104..4e08a263c 100644 --- a/tracecat/config.py +++ b/tracecat/config.py @@ -115,8 +115,9 @@ TEMPORAL__CLUSTER_QUEUE = os.environ.get( "TEMPORAL__CLUSTER_QUEUE", "tracecat-task-queue" ) -TEMPORAL__TLS_ENABLED = os.environ.get("TEMPORAL__TLS_ENABLED", False) == "true" -TEMPORAL__TLS_CERT__ARN = os.environ.get("TRACECAT__TLS_CERT_ARN") +TEMPORAL__API_KEY__ARN = os.environ.get("TEMPORAL__API_KEY__ARN") +TEMPORAL__MTLS_ENABLED = os.environ.get("TEMPORAL__MTLS_ENABLED", False) == "true" +TEMPORAL__MTLS_CERT__ARN = os.environ.get("TEMPORAL__MTLS_CERT__ARN") TEMPORAL__CLIENT_RPC_TIMEOUT = os.environ.get("TEMPORAL__CLIENT_RPC_TIMEOUT") """RPC timeout for Temporal workflows in seconds.""" diff --git a/tracecat/dsl/client.py b/tracecat/dsl/client.py index f366d2b94..cb30d1c37 100644 --- a/tracecat/dsl/client.py +++ b/tracecat/dsl/client.py @@ -13,11 +13,12 @@ ) from tracecat.config import ( + TEMPORAL__API_KEY__ARN, TEMPORAL__CLUSTER_NAMESPACE, TEMPORAL__CLUSTER_URL, TEMPORAL__CONNECT_RETRIES, - TEMPORAL__TLS_CERT__ARN, - TEMPORAL__TLS_ENABLED, + TEMPORAL__MTLS_CERT__ARN, + TEMPORAL__MTLS_ENABLED, ) from tracecat.dsl._converter import pydantic_data_converter from tracecat.logger import logger @@ -31,7 +32,7 @@ class TemporalClientCert: private_key: bytes -async def _retrieve_client_cert(arn: str) -> TemporalClientCert: +async def _retrieve_temporal_client_cert(arn: str) -> TemporalClientCert: """Retrieve the client certificate and private key from AWS Secrets Manager.""" session = aioboto3.session.get_session() async with session.client(service_name="secretsmanager") as client: @@ -43,6 +44,14 @@ async def _retrieve_client_cert(arn: str) -> TemporalClientCert: ) +async def _retrieve_temporal_api_key(arn: str) -> str: + """Retrieve the Temporal API key from AWS Secrets Manager.""" + session = aioboto3.session.get_session() + async with session.client(service_name="secretsmanager") as client: + response = await client.get_secret_value(SecretId=arn) + return response["SecretString"] + + @retry( stop=stop_after_attempt(TEMPORAL__CONNECT_RETRIES), wait=wait_exponential(multiplier=1, min=1, max=10), @@ -50,25 +59,27 @@ async def _retrieve_client_cert(arn: str) -> TemporalClientCert: reraise=True, ) async def connect_to_temporal() -> Client: + api_key = None tls_config = False - if TEMPORAL__TLS_ENABLED: - if not TEMPORAL__TLS_CERT__ARN: + if TEMPORAL__MTLS_ENABLED: + if not TEMPORAL__MTLS_CERT__ARN: raise ValueError( - "MTLS enabled for Temporal but `TEMPORAL__TLS_CERT_ARN` is not set" + "MTLS enabled for Temporal but `TEMPORAL__MTLS_CERT_ARN` is not set" ) - - logger.info("Retrieving Temporal MTLS client certificate...") - client_cert = await _retrieve_client_cert(arn=TEMPORAL__TLS_CERT__ARN) - logger.info("Successfully retrieved Temporal MTLS client certificate") + client_cert = await _retrieve_temporal_client_cert(arn=TEMPORAL__MTLS_CERT__ARN) tls_config = TLSConfig( client_cert=client_cert.cert, client_private_key=client_cert.private_key, ) + if TEMPORAL__API_KEY__ARN: + api_key = await _retrieve_temporal_api_key(arn=TEMPORAL__API_KEY__ARN) + client = await Client.connect( target_host=TEMPORAL__CLUSTER_URL, namespace=TEMPORAL__CLUSTER_NAMESPACE, + api_key=api_key, tls=tls_config, data_converter=pydantic_data_converter, )