Skip to content

Commit

Permalink
Support Temporal API key
Browse files Browse the repository at this point in the history
  • Loading branch information
topher-lo committed Dec 21, 2024
1 parent 023f2d6 commit d008232
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
5 changes: 3 additions & 2 deletions tracecat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,9 @@
TEMPORAL__CLUSTER_QUEUE = os.environ.get(
"TEMPORAL__CLUSTER_QUEUE", "tracecat-task-queue"
)
TEMPORAL__TLS_ENABLED = os.environ.get("TEMPORAL__TLS_ENABLED", False) == "true"
TEMPORAL__TLS_CERT__ARN = os.environ.get("TRACECAT__TLS_CERT_ARN")
TEMPORAL__API_KEY__ARN = os.environ.get("TEMPORAL__API_KEY__ARN")
TEMPORAL__MTLS_ENABLED = os.environ.get("TEMPORAL__MTLS_ENABLED", False) == "true"
TEMPORAL__MTLS_CERT__ARN = os.environ.get("TEMPORAL__MTLS_CERT__ARN")
TEMPORAL__CLIENT_RPC_TIMEOUT = os.environ.get("TEMPORAL__CLIENT_RPC_TIMEOUT")
"""RPC timeout for Temporal workflows in seconds."""

Expand Down
31 changes: 21 additions & 10 deletions tracecat/dsl/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
)

from tracecat.config import (
TEMPORAL__API_KEY__ARN,
TEMPORAL__CLUSTER_NAMESPACE,
TEMPORAL__CLUSTER_URL,
TEMPORAL__CONNECT_RETRIES,
TEMPORAL__TLS_CERT__ARN,
TEMPORAL__TLS_ENABLED,
TEMPORAL__MTLS_CERT__ARN,
TEMPORAL__MTLS_ENABLED,
)
from tracecat.dsl._converter import pydantic_data_converter
from tracecat.logger import logger
Expand All @@ -31,7 +32,7 @@ class TemporalClientCert:
private_key: bytes


async def _retrieve_client_cert(arn: str) -> TemporalClientCert:
async def _retrieve_temporal_client_cert(arn: str) -> TemporalClientCert:
"""Retrieve the client certificate and private key from AWS Secrets Manager."""
session = aioboto3.session.get_session()
async with session.client(service_name="secretsmanager") as client:
Expand All @@ -43,32 +44,42 @@ async def _retrieve_client_cert(arn: str) -> TemporalClientCert:
)


async def _retrieve_temporal_api_key(arn: str) -> str:
"""Retrieve the Temporal API key from AWS Secrets Manager."""
session = aioboto3.session.get_session()
async with session.client(service_name="secretsmanager") as client:
response = await client.get_secret_value(SecretId=arn)
return response["SecretString"]


@retry(
stop=stop_after_attempt(TEMPORAL__CONNECT_RETRIES),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type(TemporalError),
reraise=True,
)
async def connect_to_temporal() -> Client:
api_key = None
tls_config = False

if TEMPORAL__TLS_ENABLED:
if not TEMPORAL__TLS_CERT__ARN:
if TEMPORAL__MTLS_ENABLED:
if not TEMPORAL__MTLS_CERT__ARN:
raise ValueError(
"MTLS enabled for Temporal but `TEMPORAL__TLS_CERT_ARN` is not set"
"MTLS enabled for Temporal but `TEMPORAL__MTLS_CERT_ARN` is not set"
)

logger.info("Retrieving Temporal MTLS client certificate...")
client_cert = await _retrieve_client_cert(arn=TEMPORAL__TLS_CERT__ARN)
logger.info("Successfully retrieved Temporal MTLS client certificate")
client_cert = await _retrieve_temporal_client_cert(arn=TEMPORAL__MTLS_CERT__ARN)
tls_config = TLSConfig(
client_cert=client_cert.cert,
client_private_key=client_cert.private_key,
)

if TEMPORAL__API_KEY__ARN:
api_key = await _retrieve_temporal_api_key(arn=TEMPORAL__API_KEY__ARN)

client = await Client.connect(
target_host=TEMPORAL__CLUSTER_URL,
namespace=TEMPORAL__CLUSTER_NAMESPACE,
api_key=api_key,
tls=tls_config,
data_converter=pydantic_data_converter,
)
Expand Down

0 comments on commit d008232

Please sign in to comment.