Skip to content

Commit

Permalink
Update temporal client
Browse files Browse the repository at this point in the history
  • Loading branch information
topher-lo committed Dec 21, 2024
1 parent 6bd9d2b commit 023f2d6
Showing 1 changed file with 59 additions and 23 deletions.
82 changes: 59 additions & 23 deletions tracecat/dsl/client.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,74 @@
from dataclasses import dataclass

import aioboto3
from temporalio.client import Client
from temporalio.exceptions import TemporalError
from temporalio.service import TLSConfig
from tenacity import (
RetryError,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)

from tracecat import config
from tracecat.config import (
TEMPORAL__CLUSTER_NAMESPACE,
TEMPORAL__CLUSTER_URL,
TEMPORAL__CONNECT_RETRIES,
TEMPORAL__TLS_CERT__ARN,
TEMPORAL__TLS_ENABLED,
)
from tracecat.dsl._converter import pydantic_data_converter
from tracecat.logger import logger

_N_RETRIES = 10
_client: Client | None = None


@dataclass(frozen=True, slots=True)
class TemporalClientCert:
cert: bytes
private_key: bytes


async def _retrieve_client_cert(arn: str) -> TemporalClientCert:
"""Retrieve the client certificate and private key from AWS Secrets Manager."""
session = aioboto3.session.get_session()
async with session.client(service_name="secretsmanager") as client:
response = await client.get_secret_value(SecretId=arn)
secret = response["SecretString"]
return TemporalClientCert(
cert=secret["cert"].encode(),
private_key=secret["private_key"].encode(),
)


@retry(
stop=stop_after_attempt(_N_RETRIES),
stop=stop_after_attempt(TEMPORAL__CONNECT_RETRIES),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type(RuntimeError),
retry=retry_if_exception_type(TemporalError),
reraise=True,
)
async def connect_to_temporal() -> Client:
tls_config = False
if config.TEMPORAL__TLS_ENABLED:
if (
config.TEMPORAL__TLS_CLIENT_CERT is None
or config.TEMPORAL__TLS_CLIENT_PRIVATE_KEY is None
):
raise RuntimeError(
"TLS is enabled but no client certificate or private key is provided"

if TEMPORAL__TLS_ENABLED:
if not TEMPORAL__TLS_CERT__ARN:
raise ValueError(
"MTLS enabled for Temporal but `TEMPORAL__TLS_CERT_ARN` is not set"
)
logger.info("TLS enabled for Temporal")

logger.info("Retrieving Temporal MTLS client certificate...")
client_cert = await _retrieve_client_cert(arn=TEMPORAL__TLS_CERT__ARN)
logger.info("Successfully retrieved Temporal MTLS client certificate")
tls_config = TLSConfig(
client_cert=config.TEMPORAL__TLS_CLIENT_CERT.encode(),
client_private_key=config.TEMPORAL__TLS_CLIENT_PRIVATE_KEY.encode(),
client_cert=client_cert.cert,
client_private_key=client_cert.private_key,
)

client = await Client.connect(
target_host=config.TEMPORAL__CLUSTER_URL,
namespace=config.TEMPORAL__CLUSTER_NAMESPACE,
target_host=TEMPORAL__CLUSTER_URL,
namespace=TEMPORAL__CLUSTER_NAMESPACE,
tls=tls_config,
data_converter=pydantic_data_converter,
)
Expand All @@ -51,13 +80,20 @@ async def get_temporal_client() -> Client:

if _client is not None:
return _client
logger.info(f"Connecting to Temporal at {config.TEMPORAL__CLUSTER_URL}")

try:
logger.info(
"Connecting to Temporal server...",
namespace=TEMPORAL__CLUSTER_NAMESPACE,
url=TEMPORAL__CLUSTER_URL,
)
_client = await connect_to_temporal()
logger.info("Successfully connected to Temporal")
return _client
except Exception as e:
logger.error(
f"Failed to connect to Temporal after {_N_RETRIES} attempts: {str(e)}"
logger.info("Successfully connected to Temporal server")
except RetryError as e:
msg = (
f"Failed to connect to host {TEMPORAL__CLUSTER_URL} using namespace "
f"{TEMPORAL__CLUSTER_NAMESPACE} after {TEMPORAL__CONNECT_RETRIES} attempts. "
)
raise
raise RuntimeError(msg) from e
else:
return _client

0 comments on commit 023f2d6

Please sign in to comment.