From 023f2d624b639a0d9260b6dc9565a20fee6c97dc Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Fri, 20 Dec 2024 23:19:36 -0800 Subject: [PATCH] Update temporal client --- tracecat/dsl/client.py | 82 ++++++++++++++++++++++++++++++------------ 1 file changed, 59 insertions(+), 23 deletions(-) diff --git a/tracecat/dsl/client.py b/tracecat/dsl/client.py index 2836c1250..f366d2b94 100644 --- a/tracecat/dsl/client.py +++ b/tracecat/dsl/client.py @@ -1,45 +1,74 @@ +from dataclasses import dataclass + +import aioboto3 from temporalio.client import Client +from temporalio.exceptions import TemporalError from temporalio.service import TLSConfig from tenacity import ( + RetryError, retry, retry_if_exception_type, stop_after_attempt, wait_exponential, ) -from tracecat import config +from tracecat.config import ( + TEMPORAL__CLUSTER_NAMESPACE, + TEMPORAL__CLUSTER_URL, + TEMPORAL__CONNECT_RETRIES, + TEMPORAL__TLS_CERT__ARN, + TEMPORAL__TLS_ENABLED, +) from tracecat.dsl._converter import pydantic_data_converter from tracecat.logger import logger -_N_RETRIES = 10 _client: Client | None = None +@dataclass(frozen=True, slots=True) +class TemporalClientCert: + cert: bytes + private_key: bytes + + +async def _retrieve_client_cert(arn: str) -> TemporalClientCert: + """Retrieve the client certificate and private key from AWS Secrets Manager.""" + session = aioboto3.session.get_session() + async with session.client(service_name="secretsmanager") as client: + response = await client.get_secret_value(SecretId=arn) + secret = response["SecretString"] + return TemporalClientCert( + cert=secret["cert"].encode(), + private_key=secret["private_key"].encode(), + ) + + @retry( - stop=stop_after_attempt(_N_RETRIES), + stop=stop_after_attempt(TEMPORAL__CONNECT_RETRIES), wait=wait_exponential(multiplier=1, min=1, max=10), - retry=retry_if_exception_type(RuntimeError), + retry=retry_if_exception_type(TemporalError), reraise=True, ) async def connect_to_temporal() -> Client: tls_config = False - if config.TEMPORAL__TLS_ENABLED: - if ( - config.TEMPORAL__TLS_CLIENT_CERT is None - or config.TEMPORAL__TLS_CLIENT_PRIVATE_KEY is None - ): - raise RuntimeError( - "TLS is enabled but no client certificate or private key is provided" + + if TEMPORAL__TLS_ENABLED: + if not TEMPORAL__TLS_CERT__ARN: + raise ValueError( + "MTLS enabled for Temporal but `TEMPORAL__TLS_CERT_ARN` is not set" ) - logger.info("TLS enabled for Temporal") + + logger.info("Retrieving Temporal MTLS client certificate...") + client_cert = await _retrieve_client_cert(arn=TEMPORAL__TLS_CERT__ARN) + logger.info("Successfully retrieved Temporal MTLS client certificate") tls_config = TLSConfig( - client_cert=config.TEMPORAL__TLS_CLIENT_CERT.encode(), - client_private_key=config.TEMPORAL__TLS_CLIENT_PRIVATE_KEY.encode(), + client_cert=client_cert.cert, + client_private_key=client_cert.private_key, ) client = await Client.connect( - target_host=config.TEMPORAL__CLUSTER_URL, - namespace=config.TEMPORAL__CLUSTER_NAMESPACE, + target_host=TEMPORAL__CLUSTER_URL, + namespace=TEMPORAL__CLUSTER_NAMESPACE, tls=tls_config, data_converter=pydantic_data_converter, ) @@ -51,13 +80,20 @@ async def get_temporal_client() -> Client: if _client is not None: return _client - logger.info(f"Connecting to Temporal at {config.TEMPORAL__CLUSTER_URL}") + try: + logger.info( + "Connecting to Temporal server...", + namespace=TEMPORAL__CLUSTER_NAMESPACE, + url=TEMPORAL__CLUSTER_URL, + ) _client = await connect_to_temporal() - logger.info("Successfully connected to Temporal") - return _client - except Exception as e: - logger.error( - f"Failed to connect to Temporal after {_N_RETRIES} attempts: {str(e)}" + logger.info("Successfully connected to Temporal server") + except RetryError as e: + msg = ( + f"Failed to connect to host {TEMPORAL__CLUSTER_URL} using namespace " + f"{TEMPORAL__CLUSTER_NAMESPACE} after {TEMPORAL__CONNECT_RETRIES} attempts. " ) - raise + raise RuntimeError(msg) from e + else: + return _client