Skip to content

Commit aafa854

Browse files
authored
feat(workflows): use SES tenants (#40612)
1 parent 40cfa5f commit aafa854

File tree

4 files changed

+411
-21
lines changed

4 files changed

+411
-21
lines changed
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
import logging
2+
from collections.abc import Iterable
3+
4+
from django.conf import settings
5+
from django.core.management.base import BaseCommand
6+
from django.core.paginator import Paginator
7+
from django.db.models import Q
8+
9+
import boto3
10+
from botocore.exceptions import BotoCoreError, ClientError
11+
12+
from posthog.models.integration import Integration
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
def _batched(iterable: Iterable, size: int) -> Iterable[list]:
18+
batch: list = []
19+
for item in iterable:
20+
batch.append(item)
21+
if len(batch) >= size:
22+
yield batch
23+
batch = []
24+
if batch:
25+
yield batch
26+
27+
28+
def migrate_ses_tenants(team_ids: list[int], domains: list[str], dry_run: bool = False):
29+
"""
30+
Ensure existing SES email identities have SES Tenants and Tenant Resource Associations.
31+
32+
The command is idempotent.
33+
"""
34+
if team_ids and domains:
35+
print("Please provide either team_ids or domains, not both") # noqa: T201
36+
return
37+
38+
query = (
39+
Integration.objects.filter(kind="email")
40+
.filter(Q(config__provider="ses") | Q(config__provider__isnull=True))
41+
.order_by("id")
42+
)
43+
44+
if team_ids:
45+
print("Setting up SES tenants for teams:", team_ids) # noqa: T201
46+
query = query.filter(team_id__in=team_ids)
47+
elif domains:
48+
print("Setting up SES tenants for domains:", domains) # noqa: T201
49+
# Domains are stored in Integration.config["domain"]
50+
query = query.filter(config__domain__in=domains)
51+
else:
52+
print("Setting up SES tenants for all SES email identities") # noqa: T201
53+
54+
# Collect unique (team_id, domain) pairs to avoid duplicate work per domain
55+
pairs: list[tuple[int, str]] = []
56+
paginator = Paginator(query, 200)
57+
58+
for page_num in paginator.page_range:
59+
page = paginator.page(page_num)
60+
for integration in page.object_list:
61+
domain = integration.config.get("domain")
62+
if not domain:
63+
continue
64+
provider = integration.config.get("provider", "mailjet")
65+
if provider != "ses":
66+
continue
67+
pair = (integration.team_id, domain)
68+
if pair not in pairs:
69+
pairs.append(pair)
70+
71+
if not pairs:
72+
print("No SES email identities found to migrate.") # noqa: T201
73+
return
74+
75+
sts_client = boto3.client(
76+
"sts",
77+
)
78+
tenant_client = boto3.client(
79+
"sesv2",
80+
)
81+
82+
try:
83+
aws_account_id = sts_client.get_caller_identity()["Account"]
84+
except (ClientError, BotoCoreError) as e:
85+
logger.exception("Failed to get AWS account id for SES tenant association: %s", e)
86+
print("Error determining AWS account ID. Aborting.") # noqa: T201
87+
return
88+
89+
for batch in _batched(pairs, 50):
90+
for team_id, domain in batch:
91+
tenant_name = f"team-{team_id}"
92+
identity_arn = f"arn:aws:ses:{settings.SES_REGION}:{aws_account_id}:identity/{domain}"
93+
94+
# Create tenant if missing
95+
try:
96+
if dry_run:
97+
print(f"[DRY-RUN] Would ensure tenant '{tenant_name}' exists") # noqa: T201
98+
else:
99+
try:
100+
tenant_client.create_tenant(
101+
TenantName=tenant_name,
102+
Tags=[{"Key": "team_id", "Value": str(team_id)}],
103+
)
104+
print(f"Created SES tenant '{tenant_name}'") # noqa: T201
105+
except ClientError as e:
106+
if e.response.get("Error", {}).get("Code") == "AlreadyExistsException":
107+
print(f"Tenant '{tenant_name}' already exists") # noqa: T201
108+
else:
109+
raise
110+
except (ClientError, BotoCoreError) as e:
111+
logger.exception("Error creating SES tenant '%s': %s", tenant_name, e)
112+
print(f"Error creating tenant '{tenant_name}': {e}") # noqa: T201
113+
continue
114+
115+
# Create association if missing
116+
try:
117+
if dry_run:
118+
print(f"[DRY-RUN] Would associate identity '{identity_arn}' with tenant '{tenant_name}'") # noqa: T201
119+
else:
120+
try:
121+
tenant_client.create_tenant_resource_association(
122+
TenantName=tenant_name,
123+
ResourceArn=identity_arn,
124+
)
125+
print(f"Associated identity '{domain}' with tenant '{tenant_name}'") # noqa: T201
126+
except ClientError as e:
127+
if e.response.get("Error", {}).get("Code") == "AlreadyExistsException":
128+
print(f"Association already exists for '{domain}' and tenant '{tenant_name}'") # noqa: T201
129+
else:
130+
raise
131+
except (ClientError, BotoCoreError) as e:
132+
logger.exception(
133+
"Error creating SES tenant_resource_association for '%s' on '%s': %s",
134+
domain,
135+
tenant_name,
136+
e,
137+
)
138+
print(f"Error creating tenant_resource_association for '{domain}' on '{tenant_name}': {e}") # noqa: T201
139+
continue
140+
141+
142+
class Command(BaseCommand):
143+
help = "Migrate existing SES identities to use SES Tenants and resource associations"
144+
145+
def add_arguments(self, parser):
146+
parser.add_argument(
147+
"--dry-run",
148+
action="store_true",
149+
help="If set, will not perform changes, only print actions",
150+
)
151+
parser.add_argument(
152+
"--team-ids",
153+
type=str,
154+
help="Comma separated list of team ids to migrate",
155+
)
156+
parser.add_argument(
157+
"--domains",
158+
type=str,
159+
help="Comma separated list of email domains to migrate (e.g., example.com,foo.bar)",
160+
)
161+
162+
def handle(self, *args, **options):
163+
dry_run: bool = bool(options.get("dry_run"))
164+
team_ids_opt = options.get("team_ids")
165+
domains_opt = options.get("domains")
166+
167+
team_ids = [int(x) for x in team_ids_opt.split(",")] if team_ids_opt else []
168+
domains = [x.strip() for x in domains_opt.split(",")] if domains_opt else []
169+
170+
migrate_ses_tenants(team_ids=team_ids, domains=domains, dry_run=dry_run)
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
from posthog.test.base import BaseTest
2+
from unittest.mock import patch
3+
4+
from django.test import override_settings
5+
6+
from posthog.management.commands.migrate_ses_tenants import migrate_ses_tenants
7+
from posthog.models.integration import Integration
8+
9+
10+
class _FakeSESv2Client:
11+
def __init__(self):
12+
self.created_tenants: list[str] = []
13+
self.associations: list[tuple[str, str]] = []
14+
15+
def get_caller_identity(self):
16+
return {"Account": "123456789012"}
17+
18+
def create_tenant(self, TenantName: str, Tags: list[dict]): # noqa: N803
19+
# emulate idempotency externally in test assertions
20+
if TenantName in self.created_tenants:
21+
from botocore.exceptions import ClientError
22+
23+
raise ClientError({"Error": {"Code": "AlreadyExistsException", "Message": "Tenant exists"}}, "CreateTenant")
24+
self.created_tenants.append(TenantName)
25+
return {"TenantName": TenantName}
26+
27+
def create_tenant_resource_association(self, TenantName: str, ResourceArn: str): # noqa: N803
28+
# emulate idempotency externally in test assertions
29+
pair = (TenantName, ResourceArn)
30+
if pair in self.associations:
31+
from botocore.exceptions import ClientError
32+
33+
raise ClientError(
34+
{"Error": {"Code": "AlreadyExistsException", "Message": "Association exists"}},
35+
"CreateTenantResourceAssociation",
36+
)
37+
self.associations.append(pair)
38+
return {"TenantName": TenantName, "ResourceArn": ResourceArn}
39+
40+
41+
class TestMigrateSESTenants(BaseTest):
42+
def setUp(self):
43+
super().setUp()
44+
# Two SES email integrations on the same domain (should dedupe by (team, domain))
45+
Integration.objects.create(
46+
team=self.team,
47+
kind="email",
48+
integration_id="[email protected]",
49+
config={"domain": "example.com", "provider": "ses"},
50+
created_by=self.user,
51+
)
52+
Integration.objects.create(
53+
team=self.team,
54+
kind="email",
55+
integration_id="[email protected]",
56+
config={"domain": "example.com", "provider": "ses"},
57+
created_by=self.user,
58+
)
59+
# Non-SES provider should be ignored
60+
Integration.objects.create(
61+
team=self.team,
62+
kind="email",
63+
integration_id="[email protected]",
64+
config={"domain": "other.com", "provider": "mailjet"},
65+
created_by=self.user,
66+
)
67+
68+
@override_settings(SES_ACCESS_KEY_ID="test", SES_SECRET_ACCESS_KEY="test", SES_REGION="us-east-1", SES_ENDPOINT="")
69+
@patch("posthog.management.commands.migrate_ses_tenants.boto3.client")
70+
def test_dry_run(self, mock_boto_client):
71+
# Arrange stub clients
72+
sesv2 = _FakeSESv2Client()
73+
mock_boto_client.side_effect = lambda service, **kwargs: sesv2
74+
75+
# Act: dry-run should not attempt create calls but will still resolve account id
76+
migrate_ses_tenants(team_ids=[], domains=[], dry_run=True)
77+
78+
# Assert: no tenants/associations performed
79+
assert sesv2.created_tenants == []
80+
assert sesv2.associations == []
81+
82+
@override_settings(SES_ACCESS_KEY_ID="test", SES_SECRET_ACCESS_KEY="test", SES_REGION="us-east-1", SES_ENDPOINT="")
83+
@patch("posthog.management.commands.migrate_ses_tenants.boto3.client")
84+
def test_migrate_for_team(self, mock_boto_client):
85+
sesv2 = _FakeSESv2Client()
86+
mock_boto_client.side_effect = lambda service, **kwargs: sesv2
87+
88+
migrate_ses_tenants(team_ids=[self.team.id], domains=[], dry_run=False)
89+
90+
# Deduped: only one tenant and one association for (team, example.com)
91+
assert sesv2.created_tenants == [f"team-{self.team.id}"]
92+
expected_arn = f"arn:aws:ses:us-east-1:123456789012:identity/example.com"
93+
assert sesv2.associations == [(f"team-{self.team.id}", expected_arn)]
94+
95+
@override_settings(SES_ACCESS_KEY_ID="test", SES_SECRET_ACCESS_KEY="test", SES_REGION="eu-west-1", SES_ENDPOINT="")
96+
@patch("posthog.management.commands.migrate_ses_tenants.boto3.client")
97+
def test_migrate_for_domain_filter(self, mock_boto_client):
98+
sesv2 = _FakeSESv2Client()
99+
mock_boto_client.side_effect = lambda service, **kwargs: sesv2
100+
101+
# Use domains filter; should match example.com only
102+
migrate_ses_tenants(team_ids=[], domains=["example.com"], dry_run=False)
103+
104+
assert sesv2.created_tenants == [f"team-{self.team.id}"]
105+
expected_arn = f"arn:aws:ses:eu-west-1:123456789012:identity/example.com"
106+
assert sesv2.associations == [(f"team-{self.team.id}", expected_arn)]
107+
108+
@override_settings(SES_ACCESS_KEY_ID="test", SES_SECRET_ACCESS_KEY="test", SES_REGION="us-east-1", SES_ENDPOINT="")
109+
@patch("posthog.management.commands.migrate_ses_tenants.boto3.client")
110+
def test_idempotent_on_repeated_run(self, mock_boto_client):
111+
sesv2 = _FakeSESv2Client()
112+
mock_boto_client.side_effect = lambda service, **kwargs: sesv2
113+
114+
# First run creates
115+
migrate_ses_tenants(team_ids=[self.team.id], domains=[], dry_run=False)
116+
# Second run should hit AlreadyExistsException internally and not error
117+
migrate_ses_tenants(team_ids=[self.team.id], domains=[], dry_run=False)
118+
119+
# Still only one tenant and association recorded
120+
assert sesv2.created_tenants == [f"team-{self.team.id}"]
121+
expected_arn = f"arn:aws:ses:us-east-1:123456789012:identity/example.com"
122+
assert sesv2.associations == [(f"team-{self.team.id}", expected_arn)]

products/workflows/backend/providers/ses.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,48 @@
1212

1313
class SESProvider:
1414
def __init__(self):
15-
# Initialize SES client
16-
self.client = boto3.client(
15+
# Initialize the boto3 clients
16+
self.sts_client = boto3.client(
17+
"sts",
18+
aws_access_key_id=settings.SES_ACCESS_KEY_ID,
19+
aws_secret_access_key=settings.SES_SECRET_ACCESS_KEY,
20+
region_name=settings.SES_REGION,
21+
)
22+
self.ses_client = boto3.client(
1723
"ses",
1824
aws_access_key_id=settings.SES_ACCESS_KEY_ID,
1925
aws_secret_access_key=settings.SES_SECRET_ACCESS_KEY,
2026
region_name=settings.SES_REGION,
21-
endpoint_url=settings.SES_ENDPOINT if settings.SES_ENDPOINT else None,
27+
)
28+
self.ses_v2_client = boto3.client(
29+
"sesv2",
30+
aws_access_key_id=settings.SES_ACCESS_KEY_ID,
31+
aws_secret_access_key=settings.SES_SECRET_ACCESS_KEY,
32+
region_name=settings.SES_REGION,
2233
)
2334

2435
def create_email_domain(self, domain: str, team_id: int):
25-
# NOTE: For sesv1 creation is done through verification
36+
# NOTE: For sesv1, domain Identity creation is done through verification
2637
self.verify_email_domain(domain, team_id)
2738

39+
# Create a tenant for the domain if not exists
40+
tenant_name = f"team-{team_id}"
41+
try:
42+
self.ses_v2_client.create_tenant(TenantName=tenant_name, Tags=[{"Key": "team_id", "Value": str(team_id)}])
43+
except ClientError as e:
44+
if e.response["Error"]["Code"] != "AlreadyExistsException":
45+
raise
46+
47+
# Associate the new domain identity with the tenant
48+
try:
49+
self.ses_v2_client.create_tenant_resource_association(
50+
TenantName=tenant_name,
51+
ResourceArn=f"arn:aws:ses:{settings.SES_REGION}:{self.sts_client.get_caller_identity()['Account']}:identity/{domain}",
52+
)
53+
except ClientError as e:
54+
if e.response["Error"]["Code"] != "AlreadyExistsException":
55+
raise
56+
2857
def verify_email_domain(self, domain: str, team_id: int):
2958
# Validate the domain contains valid characters for a domain name
3059
DOMAIN_REGEX = r"(?i)^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$"
@@ -36,7 +65,7 @@ def verify_email_domain(self, domain: str, team_id: int):
3665
# Start/ensure domain verification (TXT at _amazonses.domain) ---
3766
verification_token = None
3867
try:
39-
resp = self.client.verify_domain_identity(Domain=domain)
68+
resp = self.ses_client.verify_domain_identity(Domain=domain)
4069
verification_token = resp.get("VerificationToken")
4170
except ClientError as e:
4271
# If already requested/exists, carry on; SES v1 is idempotent-ish here
@@ -57,7 +86,7 @@ def verify_email_domain(self, domain: str, team_id: int):
5786
# Start/ensure DKIM (three CNAMEs) ---
5887
dkim_tokens: list[str] = []
5988
try:
60-
resp = self.client.verify_domain_dkim(Domain=domain)
89+
resp = self.ses_client.verify_domain_dkim(Domain=domain)
6190
dkim_tokens = resp.get("DkimTokens", []) or []
6291
except ClientError as e:
6392
if e.response["Error"]["Code"] not in ("InvalidParameterValue",):
@@ -86,15 +115,15 @@ def verify_email_domain(self, domain: str, team_id: int):
86115

87116
# Current verification / DKIM statuses to compute overall status & per-record statuses ---
88117
try:
89-
id_attrs = self.client.get_identity_verification_attributes(Identities=[domain])
118+
id_attrs = self.ses_client.get_identity_verification_attributes(Identities=[domain])
90119
verification_status = (
91120
id_attrs["VerificationAttributes"].get(domain, {}).get("VerificationStatus", "Unknown")
92121
)
93122
except ClientError:
94123
verification_status = "Unknown"
95124

96125
try:
97-
dkim_attrs = self.client.get_identity_dkim_attributes(Identities=[domain])
126+
dkim_attrs = self.ses_client.get_identity_dkim_attributes(Identities=[domain])
98127
dkim_status = dkim_attrs["DkimAttributes"].get(domain, {}).get("DkimVerificationStatus", "Unknown")
99128
except ClientError:
100129
dkim_status = "Unknown"
@@ -131,7 +160,7 @@ def delete_identity(self, identity: str):
131160
Delete an identity from SES
132161
"""
133162
try:
134-
self.client.delete_identity(Identity=identity)
163+
self.ses_client.delete_identity(Identity=identity)
135164
logger.info(f"Identity {identity} deleted from SES")
136165
except (ClientError, BotoCoreError) as e:
137166
logger.exception(f"SES API error deleting identity: {e}")

0 commit comments

Comments
 (0)