diff --git a/deployments/aws/ecs/ecs-temporal-ui.tf b/deployments/aws/ecs/ecs-temporal-ui.tf index 328fbf54d..3e2af29e8 100644 --- a/deployments/aws/ecs/ecs-temporal-ui.tf +++ b/deployments/aws/ecs/ecs-temporal-ui.tf @@ -18,7 +18,7 @@ resource "aws_ecs_task_definition" "temporal_ui_task_definition" { container_definitions = jsonencode([ { name = "TemporalUiContainer" - image = "temporalio/ui:${var.temporal_ui_image_tag}" + image = "${var.temporal_ui_image}:${var.temporal_ui_image_tag}" portMappings = [ { containerPort = 8080 diff --git a/deployments/aws/ecs/ecs-temporal.tf b/deployments/aws/ecs/ecs-temporal.tf index 5d6dfa13f..4be309dd0 100644 --- a/deployments/aws/ecs/ecs-temporal.tf +++ b/deployments/aws/ecs/ecs-temporal.tf @@ -1,5 +1,6 @@ # ECS Task Definition for Temporal Service resource "aws_ecs_task_definition" "temporal_task_definition" { + count = var.disable_temporal_autosetup ? 0 : 1 family = "TracecatTemporalTaskDefinition" network_mode = "awsvpc" requires_compatibilities = ["FARGATE"] @@ -55,9 +56,10 @@ resource "aws_ecs_task_definition" "temporal_task_definition" { } resource "aws_ecs_service" "temporal_service" { + count = var.disable_temporal_autosetup ? 0 : 1 name = "temporal-server" cluster = aws_ecs_cluster.tracecat_cluster.id - task_definition = aws_ecs_task_definition.temporal_task_definition.arn + task_definition = aws_ecs_task_definition.temporal_task_definition[0].arn launch_type = "FARGATE" desired_count = 1 diff --git a/deployments/aws/ecs/locals.tf b/deployments/aws/ecs/locals.tf index 58faf9f1b..2824fdd56 100644 --- a/deployments/aws/ecs/locals.tf +++ b/deployments/aws/ecs/locals.tf @@ -12,10 +12,15 @@ locals { saml_acs_url = "https://${var.domain_name}/api/auth/saml/acs" internal_api_url = "http://api-service:8000" # Service connect DNS name internal_executor_url = "http://executor-service:8002" # Service connect DNS name - temporal_cluster_url = "temporal-service:7233" - temporal_cluster_queue = "tracecat-task-queue" + temporal_cluster_url = var.temporal_cluster_url + temporal_cluster_queue = var.temporal_cluster_queue + temporal_namespace = var.temporal_namespace allow_origins = "${var.domain_name},http://ui-service:3000" # Allow api service and public app to access the API + # Temporal client authentication + temporal_mtls_cert_arn = var.temporal_mtls_cert_arn + temporal_api_key_arn = var.temporal_api_key_arn + # Tracecat postgres env vars # See: https://github.com/TracecatHQ/tracecat/blob/abd5ff/tracecat/db/engine.py#L21 tracecat_db_configs = { @@ -31,8 +36,12 @@ locals { RUN_MIGRATIONS = "true" SAML_SP_ACS_URL = local.saml_acs_url TEMPORAL__CLIENT_RPC_TIMEOUT = var.temporal_client_rpc_timeout + TEMPORAL__CLUSTER_NAMESPACE = local.temporal_namespace TEMPORAL__CLUSTER_QUEUE = local.temporal_cluster_queue TEMPORAL__CLUSTER_URL = local.temporal_cluster_url + TEMPORAL__MTLS_ENABLED = var.temporal_mtls_enabled + TEMPORAL__MTLS_CERT__ARN = local.temporal_mtls_cert_arn + TEMPORAL__API_KEY__ARN = local.temporal_api_key_arn TRACECAT__ALLOW_ORIGINS = local.allow_origins TRACECAT__API_ROOT_PATH = "/api" TRACECAT__API_URL = local.internal_api_url @@ -40,11 +49,11 @@ locals { TRACECAT__AUTH_ALLOWED_DOMAINS = var.auth_allowed_domains TRACECAT__AUTH_TYPES = var.auth_types TRACECAT__DB_ENDPOINT = local.core_db_hostname - TRACECAT__PUBLIC_APP_URL = local.public_app_url + TRACECAT__EXECUTOR_URL = local.internal_executor_url TRACECAT__PUBLIC_API_URL = local.public_api_url + TRACECAT__PUBLIC_APP_URL = local.public_app_url TRACECAT__REMOTE_REPOSITORY_PACKAGE_NAME = var.remote_repository_package_name TRACECAT__REMOTE_REPOSITORY_URL = var.remote_repository_url - TRACECAT__EXECUTOR_URL = local.internal_executor_url }, local.tracecat_db_configs) : { name = k, value = tostring(v) } ] @@ -52,16 +61,20 @@ locals { worker_env = [ for k, v in merge({ LOG_LEVEL = var.log_level - TRACECAT__API_URL = local.internal_api_url + TEMPORAL__CLIENT_RPC_TIMEOUT = var.temporal_client_rpc_timeout + TEMPORAL__CLUSTER_NAMESPACE = local.temporal_namespace + TEMPORAL__CLUSTER_QUEUE = local.temporal_cluster_queue + TEMPORAL__CLUSTER_URL = local.temporal_cluster_url + TEMPORAL__MTLS_ENABLED = var.temporal_mtls_enabled + TEMPORAL__MTLS_CERT__ARN = local.temporal_mtls_cert_arn + TEMPORAL__API_KEY__ARN = local.temporal_api_key_arn TRACECAT__API_ROOT_PATH = "/api" + TRACECAT__API_URL = local.internal_api_url TRACECAT__APP_ENV = var.tracecat_app_env TRACECAT__DB_ENDPOINT = local.core_db_hostname - TRACECAT__PUBLIC_API_URL = local.public_api_url - TEMPORAL__CLUSTER_URL = local.temporal_cluster_url - TEMPORAL__CLUSTER_QUEUE = local.temporal_cluster_queue - TEMPORAL__CLIENT_RPC_TIMEOUT = var.temporal_client_rpc_timeout - TRACECAT__EXECUTOR_URL = local.internal_executor_url TRACECAT__EXECUTOR_CLIENT_TIMEOUT = var.executor_client_timeout + TRACECAT__EXECUTOR_URL = local.internal_executor_url + TRACECAT__PUBLIC_API_URL = local.public_api_url }, local.tracecat_db_configs) : { name = k, value = tostring(v) } ] diff --git a/deployments/aws/ecs/variables.tf b/deployments/aws/ecs/variables.tf index e212d5734..970a2426b 100644 --- a/deployments/aws/ecs/variables.tf +++ b/deployments/aws/ecs/variables.tf @@ -34,12 +34,6 @@ variable "allowed_inbound_cidr_blocks" { default = ["0.0.0.0/0"] } -# variable "allowed_outbound_cidr_blocks" { -# description = "List of CIDR blocks the ALB can send traffic to" -# type = list(string) -# default = [] # Empty by default, will be set to VPC CIDR -# } - variable "enable_waf" { description = "Whether to enable WAF for the ALB" type = bool @@ -80,6 +74,21 @@ variable "auth_allowed_domains" { ### Images and Versions +variable "tracecat_image" { + type = string + default = "ghcr.io/tracecathq/tracecat" +} + +variable "tracecat_ui_image" { + type = string + default = "ghcr.io/tracecathq/tracecat-ui" +} + +variable "tracecat_image_tag" { + type = string + default = "0.18.2" +} + variable "temporal_server_image" { type = string default = "temporalio/auto-setup" @@ -90,24 +99,25 @@ variable "temporal_server_image_tag" { default = "1.24.2" } -variable "temporal_ui_image_tag" { +variable "temporal_ui_image" { type = string - default = "2.32.0" + default = "temporalio/ui" } -variable "tracecat_image" { +variable "temporal_ui_image_tag" { type = string - default = "ghcr.io/tracecathq/tracecat" + default = "2.32.0" } -variable "tracecat_ui_image" { - type = string - default = "ghcr.io/tracecathq/tracecat-ui" +variable "force_new_deployment" { + type = bool + description = "Force a new deployment of Tracecat services. Used to update services with new images." + default = false } -variable "disable_temporal_ui" { +variable "use_git_commit_sha" { type = bool - description = "Whether to disable the Temporal UI service in the deployment" + description = "Use the git commit SHA as the image tag" default = false } @@ -117,23 +127,66 @@ variable "TFC_CONFIGURATION_VERSION_GIT_COMMIT_SHA" { default = null } -variable "tracecat_image_tag" { - type = string - default = "0.18.2" +### Temporal configuration + +variable "disable_temporal_ui" { + type = bool + description = "Whether to disable the Temporal UI service in the deployment" + default = false } -variable "use_git_commit_sha" { +variable "disable_temporal_autosetup" { type = bool - description = "Use the git commit SHA as the image tag" + description = "Whether to disable the Temporal auto-setup service in the deployment" default = false } -variable "force_new_deployment" { +variable "temporal_mtls_enabled" { type = bool - description = "Force a new deployment of Tracecat services. Used to update services with new images." + description = "Whether to enable MTLS for the Temporal client" default = false } +variable "temporal_cluster_url" { + type = string + description = "Host and port of the Temporal server to connect to" + default = "temporal-service:7233" +} + +variable "temporal_cluster_queue" { + type = string + description = "Temporal task queue to use for client calls" + default = "default" +} + +variable "temporal_namespace" { + type = string + description = "Temporal namespace to use for client calls" + default = "default" +} + + +### Container Env Vars +# NOTE: sensitive variables are stored in secrets manager +# and specified directly in the task definition via a secret reference + +variable "tracecat_app_env" { + type = string + description = "The environment of the Tracecat application" + default = "production" +} + +variable "log_level" { + type = string + description = "Log level for the application" + default = "INFO" +} + +variable "temporal_log_level" { + type = string + default = "warn" +} + ### Secret ARNs variable "tracecat_db_encryption_key_arn" { @@ -195,6 +248,20 @@ variable "temporal_auth_client_secret_arn" { default = null } +# Temporal client + +variable "temporal_mtls_cert_arn" { + type = string + description = "The ARN of the secret containing the Temporal client certificate (optional)" + default = null +} + +variable "temporal_api_key_arn" { + type = string + description = "The ARN of the secret containing the Temporal API key (optional)" + default = null +} + ### (Optional) Custom Integrations variable "remote_repository_package_name" { @@ -347,24 +414,3 @@ variable "rds_auto_minor_version_upgrade" { description = "Enable auto minor version upgrades for RDS instances" default = false } - -### Container Env Vars -# NOTE: sensitive variables are stored in secrets manager -# and specified directly in the task definition via a secret reference - -variable "tracecat_app_env" { - type = string - description = "The environment of the Tracecat application" - default = "production" -} - -variable "log_level" { - type = string - description = "Log level for the application" - default = "INFO" -} - -variable "temporal_log_level" { - type = string - default = "warn" -} diff --git a/deployments/aws/main.tf b/deployments/aws/main.tf index cea1d7933..12952ecb8 100644 --- a/deployments/aws/main.tf +++ b/deployments/aws/main.tf @@ -36,10 +36,24 @@ module "ecs" { hosted_zone_id = var.hosted_zone_id # Tracecat version - TFC_CONFIGURATION_VERSION_GIT_COMMIT_SHA = var.TFC_CONFIGURATION_VERSION_GIT_COMMIT_SHA + tracecat_image = var.tracecat_image + tracecat_ui_image = var.tracecat_ui_image tracecat_image_tag = var.tracecat_image_tag - use_git_commit_sha = var.use_git_commit_sha + temporal_server_image = var.temporal_server_image + temporal_server_image_tag = var.temporal_server_image_tag + temporal_ui_image = var.temporal_ui_image + temporal_ui_image_tag = var.temporal_ui_image_tag force_new_deployment = var.force_new_deployment + use_git_commit_sha = var.use_git_commit_sha + TFC_CONFIGURATION_VERSION_GIT_COMMIT_SHA = var.TFC_CONFIGURATION_VERSION_GIT_COMMIT_SHA + + # Temporal configuration + disable_temporal_ui = var.disable_temporal_ui + disable_temporal_autosetup = var.disable_temporal_autosetup + temporal_mtls_enabled = var.temporal_mtls_enabled + temporal_cluster_url = var.temporal_cluster_url + temporal_cluster_queue = var.temporal_cluster_queue + temporal_namespace = var.temporal_namespace # Container environment variables tracecat_app_env = var.tracecat_app_env @@ -71,11 +85,14 @@ module "ecs" { saml_idp_certificate_arn = var.saml_idp_certificate_arn saml_idp_metadata_url_arn = var.saml_idp_metadata_url_arn - # Temporal UI + # Temporal UI authentication temporal_auth_provider_url = var.temporal_auth_provider_url temporal_auth_client_id_arn = var.temporal_auth_client_id_arn temporal_auth_client_secret_arn = var.temporal_auth_client_secret_arn - disable_temporal_ui = var.disable_temporal_ui + + # Temporal client authentication + temporal_mtls_cert_arn = var.temporal_mtls_cert_arn + temporal_api_key_arn = var.temporal_api_key_arn # Compute / memory api_cpu = var.api_cpu diff --git a/deployments/aws/variables.tf b/deployments/aws/variables.tf index 011c10165..8b1ea4bef 100644 --- a/deployments/aws/variables.tf +++ b/deployments/aws/variables.tf @@ -33,10 +33,14 @@ variable "auth_allowed_domains" { ### Images and Versions -variable "TFC_CONFIGURATION_VERSION_GIT_COMMIT_SHA" { - description = "Terraform Cloud only: the git commit SHA of that triggered the run" - type = string - default = null +variable "tracecat_image" { + type = string + default = "ghcr.io/tracecathq/tracecat" +} + +variable "tracecat_ui_image" { + type = string + default = "ghcr.io/tracecathq/tracecat-ui" } variable "tracecat_image_tag" { @@ -44,10 +48,24 @@ variable "tracecat_image_tag" { default = "0.18.2" } -variable "use_git_commit_sha" { - type = bool - description = "Use the git commit SHA as the image tag" - default = false +variable "temporal_server_image" { + type = string + default = "temporalio/auto-setup" +} + +variable "temporal_server_image_tag" { + type = string + default = "1.24.2" +} + +variable "temporal_ui_image" { + type = string + default = "temporalio/ui" +} + +variable "temporal_ui_image_tag" { + type = string + default = "2.32.0" } variable "force_new_deployment" { @@ -56,12 +74,78 @@ variable "force_new_deployment" { default = false } +variable "use_git_commit_sha" { + type = bool + description = "Use the git commit SHA as the image tag" + default = false +} + +variable "TFC_CONFIGURATION_VERSION_GIT_COMMIT_SHA" { + description = "Terraform Cloud only: the git commit SHA of that triggered the run" + type = string + default = null +} + +### Temporal configuration + variable "disable_temporal_ui" { type = bool description = "Whether to disable the Temporal UI service in the deployment" default = false } +variable "disable_temporal_autosetup" { + type = bool + description = "Whether to disable the Temporal auto-setup service in the deployment" + default = false +} + +variable "temporal_mtls_enabled" { + type = bool + description = "Whether to enable MTLS for the Temporal client" + default = false +} + +variable "temporal_cluster_url" { + type = string + description = "Host and port of the Temporal server to connect to" + default = "temporal-service:7233" +} + +variable "temporal_cluster_queue" { + type = string + description = "Temporal task queue to use for client calls" + default = "default" +} + +variable "temporal_namespace" { + type = string + description = "Temporal namespace to use for client calls" + default = "default" +} + + +### Container Env Vars +# NOTE: sensitive variables are stored in secrets manager +# and specified directly in the task definition via a secret reference + +variable "tracecat_app_env" { + type = string + description = "The environment of the Tracecat application" + default = "production" +} + +variable "log_level" { + type = string + description = "Log level for the application" + default = "INFO" +} + +variable "temporal_log_level" { + type = string + default = "warn" +} + ### Secret ARNs variable "tracecat_db_encryption_key_arn" { @@ -103,6 +187,8 @@ variable "saml_idp_metadata_url_arn" { default = null } +# Temporal UI + variable "temporal_auth_provider_url" { type = string description = "The URL of the Temporal auth provider" @@ -121,6 +207,20 @@ variable "temporal_auth_client_secret_arn" { default = null } +# Temporal client + +variable "temporal_mtls_cert_arn" { + type = string + description = "The ARN of the secret containing the Temporal client certificate (optional)" + default = null +} + +variable "temporal_api_key_arn" { + type = string + description = "The ARN of the secret containing the Temporal API key (optional)" + default = null +} + ### (Optional) Custom Integrations variable "remote_repository_package_name" { @@ -237,24 +337,3 @@ variable "rds_backup_retention_period" { description = "The number of days to retain backups for RDS instances" default = 7 } - -### Container Env Vars -# NOTE: sensitive variables are stored in secrets manager -# and specified directly in the task definition via a secret reference - -variable "tracecat_app_env" { - type = string - description = "The environment of the Tracecat application" - default = "production" -} - -variable "log_level" { - type = string - description = "Log level for the application" - default = "INFO" -} - -variable "temporal_log_level" { - type = string - default = "warn" -} diff --git a/tracecat/config.py b/tracecat/config.py index 843369339..027e3e1f4 100644 --- a/tracecat/config.py +++ b/tracecat/config.py @@ -105,19 +105,19 @@ TRACECAT__ALLOW_ORIGINS = os.environ.get("TRACECAT__ALLOW_ORIGINS") # === Temporal config === # +TEMPORAL__CONNECT_RETRIES = int(os.environ.get("TEMPORAL__CONNECT_RETRIES", 10)) TEMPORAL__CLUSTER_URL = os.environ.get( "TEMPORAL__CLUSTER_URL", "http://localhost:7233" -) # AKA Temporal target host +) # AKA TEMPORAL_HOST_URL TEMPORAL__CLUSTER_NAMESPACE = os.environ.get( "TEMPORAL__CLUSTER_NAMESPACE", "default" -) # Temporal namespace +) # AKA TEMPORAL_NAMESPACE TEMPORAL__CLUSTER_QUEUE = os.environ.get( "TEMPORAL__CLUSTER_QUEUE", "tracecat-task-queue" -) # Temporal task queue -TEMPORAL__TLS_ENABLED = os.environ.get("TEMPORAL__TLS_ENABLED", False) -TEMPORAL__TLS_ENABLED = os.environ.get("TEMPORAL__TLS_ENABLED", False) -TEMPORAL__TLS_CLIENT_CERT = os.environ.get("TEMPORAL__TLS_CLIENT_CERT") -TEMPORAL__TLS_CLIENT_PRIVATE_KEY = os.environ.get("TEMPORAL__TLS_CLIENT_PRIVATE_KEY") +) +TEMPORAL__API_KEY__ARN = os.environ.get("TEMPORAL__API_KEY__ARN") +TEMPORAL__MTLS_ENABLED = os.environ.get("TEMPORAL__MTLS_ENABLED", "").lower() in ("1", "true") +TEMPORAL__MTLS_CERT__ARN = os.environ.get("TEMPORAL__MTLS_CERT__ARN") TEMPORAL__CLIENT_RPC_TIMEOUT = os.environ.get("TEMPORAL__CLIENT_RPC_TIMEOUT") """RPC timeout for Temporal workflows in seconds.""" diff --git a/tracecat/dsl/client.py b/tracecat/dsl/client.py index 2836c1250..cb30d1c37 100644 --- a/tracecat/dsl/client.py +++ b/tracecat/dsl/client.py @@ -1,45 +1,85 @@ +from dataclasses import dataclass + +import aioboto3 from temporalio.client import Client +from temporalio.exceptions import TemporalError from temporalio.service import TLSConfig from tenacity import ( + RetryError, retry, retry_if_exception_type, stop_after_attempt, wait_exponential, ) -from tracecat import config +from tracecat.config import ( + TEMPORAL__API_KEY__ARN, + TEMPORAL__CLUSTER_NAMESPACE, + TEMPORAL__CLUSTER_URL, + TEMPORAL__CONNECT_RETRIES, + TEMPORAL__MTLS_CERT__ARN, + TEMPORAL__MTLS_ENABLED, +) from tracecat.dsl._converter import pydantic_data_converter from tracecat.logger import logger -_N_RETRIES = 10 _client: Client | None = None +@dataclass(frozen=True, slots=True) +class TemporalClientCert: + cert: bytes + private_key: bytes + + +async def _retrieve_temporal_client_cert(arn: str) -> TemporalClientCert: + """Retrieve the client certificate and private key from AWS Secrets Manager.""" + session = aioboto3.session.get_session() + async with session.client(service_name="secretsmanager") as client: + response = await client.get_secret_value(SecretId=arn) + secret = response["SecretString"] + return TemporalClientCert( + cert=secret["cert"].encode(), + private_key=secret["private_key"].encode(), + ) + + +async def _retrieve_temporal_api_key(arn: str) -> str: + """Retrieve the Temporal API key from AWS Secrets Manager.""" + session = aioboto3.session.get_session() + async with session.client(service_name="secretsmanager") as client: + response = await client.get_secret_value(SecretId=arn) + return response["SecretString"] + + @retry( - stop=stop_after_attempt(_N_RETRIES), + stop=stop_after_attempt(TEMPORAL__CONNECT_RETRIES), wait=wait_exponential(multiplier=1, min=1, max=10), - retry=retry_if_exception_type(RuntimeError), + retry=retry_if_exception_type(TemporalError), reraise=True, ) async def connect_to_temporal() -> Client: + api_key = None tls_config = False - if config.TEMPORAL__TLS_ENABLED: - if ( - config.TEMPORAL__TLS_CLIENT_CERT is None - or config.TEMPORAL__TLS_CLIENT_PRIVATE_KEY is None - ): - raise RuntimeError( - "TLS is enabled but no client certificate or private key is provided" + + if TEMPORAL__MTLS_ENABLED: + if not TEMPORAL__MTLS_CERT__ARN: + raise ValueError( + "MTLS enabled for Temporal but `TEMPORAL__MTLS_CERT_ARN` is not set" ) - logger.info("TLS enabled for Temporal") + client_cert = await _retrieve_temporal_client_cert(arn=TEMPORAL__MTLS_CERT__ARN) tls_config = TLSConfig( - client_cert=config.TEMPORAL__TLS_CLIENT_CERT.encode(), - client_private_key=config.TEMPORAL__TLS_CLIENT_PRIVATE_KEY.encode(), + client_cert=client_cert.cert, + client_private_key=client_cert.private_key, ) + if TEMPORAL__API_KEY__ARN: + api_key = await _retrieve_temporal_api_key(arn=TEMPORAL__API_KEY__ARN) + client = await Client.connect( - target_host=config.TEMPORAL__CLUSTER_URL, - namespace=config.TEMPORAL__CLUSTER_NAMESPACE, + target_host=TEMPORAL__CLUSTER_URL, + namespace=TEMPORAL__CLUSTER_NAMESPACE, + api_key=api_key, tls=tls_config, data_converter=pydantic_data_converter, ) @@ -51,13 +91,20 @@ async def get_temporal_client() -> Client: if _client is not None: return _client - logger.info(f"Connecting to Temporal at {config.TEMPORAL__CLUSTER_URL}") + try: + logger.info( + "Connecting to Temporal server...", + namespace=TEMPORAL__CLUSTER_NAMESPACE, + url=TEMPORAL__CLUSTER_URL, + ) _client = await connect_to_temporal() - logger.info("Successfully connected to Temporal") - return _client - except Exception as e: - logger.error( - f"Failed to connect to Temporal after {_N_RETRIES} attempts: {str(e)}" + logger.info("Successfully connected to Temporal server") + except RetryError as e: + msg = ( + f"Failed to connect to host {TEMPORAL__CLUSTER_URL} using namespace " + f"{TEMPORAL__CLUSTER_NAMESPACE} after {TEMPORAL__CONNECT_RETRIES} attempts. " ) - raise + raise RuntimeError(msg) from e + else: + return _client