-
Notifications
You must be signed in to change notification settings - Fork 24
/
generative_ai_txt2img_sagemaker_stack.py
88 lines (71 loc) · 4.21 KB
/
generative_ai_txt2img_sagemaker_stack.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from aws_cdk import (
Stack,
aws_iam as iam,
aws_ssm as ssm,
)
from constructs import Construct
from construct.sagemaker_endpoint_construct import SageMakerEndpointConstruct
class GenerativeAiTxt2imgSagemakerStack(Stack):
    """CDK stack deploying the Stable Diffusion text-to-image SageMaker endpoint.

    Creates an IAM execution role assumed by SageMaker, deploys the model via
    ``SageMakerEndpointConstruct`` and stores the resulting endpoint name in the
    SSM parameter ``txt2img_sm_endpoint``.
    """

    def __init__(self, scope: Construct, construct_id: str, model_info, **kwargs) -> None:
        """Build the stack.

        Args:
            scope: Parent construct of this stack.
            construct_id: Construct ID of this stack.
            model_info: Mapping read with the keys ``model_bucket_name``,
                ``model_bucket_key``, ``model_docker_image``, ``instance_type``
                and ``region_name``; the values are passed through to the
                endpoint construct / container environment.
            **kwargs: Forwarded unchanged to ``Stack``.

        Raises:
            KeyError: If ``model_info`` lacks one of the keys listed above.
        """
        super().__init__(scope, construct_id, **kwargs)

        # NOTE: despite the "Policy" in its construct ID this is the SageMaker
        # execution role. The ID is kept because changing it would change the
        # CloudFormation logical ID and replace the already-deployed role.
        role = iam.Role(
            self,
            "Gen-AI-SageMaker-Policy",
            assumed_by=iam.ServicePrincipal("sagemaker.amazonaws.com"),
        )
        # NOTE(review): full S3 access is broader than reading the model
        # artifact from model_info["model_bucket_name"]; consider scoping it.
        role.add_managed_policy(
            iam.ManagedPolicy.from_aws_managed_policy_name("AmazonS3FullAccess")
        )

        # NOTE(review): confirm the serving container really needs to assume
        # other roles; if not, this statement can be removed.
        sts_policy = self._allow_policy("sm-deploy-policy-sts", ["sts:AssumeRole"])
        logs_policy = self._allow_policy(
            "sm-deploy-policy-logs",
            [
                "cloudwatch:PutMetricData",
                "logs:CreateLogStream",
                "logs:PutLogEvents",
                "logs:CreateLogGroup",
                "logs:DescribeLogStreams",
            ],
        )
        # Pull-only ECR access: the execution role only has to fetch the
        # serving image, never push images or manage repositories.
        ecr_policy = self._allow_policy(
            "sm-deploy-policy-ecr",
            [
                "ecr:GetAuthorizationToken",
                "ecr:BatchCheckLayerAvailability",
                "ecr:GetDownloadUrlForLayer",
                "ecr:BatchGetImage",
            ],
        )
        policies = (sts_policy, logs_policy, ecr_policy)
        for policy in policies:
            role.attach_inline_policy(policy)

        endpoint = SageMakerEndpointConstruct(
            self,
            "TXT2IMG",
            project_prefix="GenerativeAiDemo",
            role_arn=role.role_arn,
            model_name="StableDiffusionText2Img",
            model_bucket_name=model_info["model_bucket_name"],
            model_bucket_key=model_info["model_bucket_key"],
            model_docker_image=model_info["model_docker_image"],
            variant_name="AllTraffic",
            variant_weight=1,
            instance_count=1,
            instance_type=model_info["instance_type"],
            environment={
                "MMS_MAX_RESPONSE_SIZE": "20000000",
                "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
                "SAGEMAKER_PROGRAM": "inference.py",
                "SAGEMAKER_REGION": model_info["region_name"],
                "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
            },
            deploy_enable=True,
        )
        # Referencing role.role_arn makes the endpoint depend on the role, but
        # the inline policies are separate resources; depend on them
        # explicitly so the permissions exist before the endpoint is created.
        for policy in policies:
            endpoint.node.add_dependency(policy)

        # Store the endpoint name in SSM Parameter Store.
        ssm.StringParameter(
            self,
            "txt2img_sm_endpoint",
            parameter_name="txt2img_sm_endpoint",
            string_value=endpoint.endpoint_name,
        )

    def _allow_policy(self, policy_id: str, actions: list) -> iam.Policy:
        """Create an inline policy in this stack allowing ``actions`` on all resources."""
        return iam.Policy(
            self,
            policy_id,
            statements=[
                iam.PolicyStatement(
                    effect=iam.Effect.ALLOW,
                    actions=actions,
                    resources=["*"],
                )
            ],
        )