diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/ecs_client.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/ecs_client.go index 6e57ba6b27e..239975db500 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/ecs_client.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/ecs_client.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "net/http" + "os" "strings" "time" @@ -231,6 +232,16 @@ func (client *ecsClient) registerContainerInstance(clusterRef string, containerI registerRequest.ClientToken = ®istrationToken resp, err := client.standardClient.RegisterContainerInstance(®isterRequest) if err != nil { + + // On error codes InvalidParameterException and Client exception, there is no need to retry as + // retries will not resolve the issue. Exit 5 to terminally stop agent process to avoid a retry storm + // due to agent restart. + if utils.IsAWSErrorCodeEqual(err, ecsmodel.ErrCodeInvalidParameterException) || + utils.IsAWSErrorCodeEqual(err, ecsmodel.ErrCodeAccessDeniedException) { + os.Exit(5) + } + // All other exceptions is deferred to agent retry per existing retry strategy + // including throttlingExceptions. logger.Error("Unable to register as a container instance with ECS", logger.Fields{ field.Error: err, }) diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/utils.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/utils.go index ede4cda84c5..a11f8d2f5d4 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/utils.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/utils.go @@ -17,6 +17,7 @@ import ( "strconv" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" "golang.org/x/exp/constraints" ) @@ -68,3 +69,11 @@ func MaxNum[T constraints.Integer | constraints.Float](a, b T) T { } return b } + +// IsAWSErrorCodeEqual returns true if the err implements Error +// interface of awserr and it has the same error code as +// the passed in error code. +func IsAWSErrorCodeEqual(err error, code string) bool { + awsErr, ok := err.(awserr.Error) + return ok && awsErr.Code() == code +} diff --git a/ecs-agent/api/ecs/client/ecs_client.go b/ecs-agent/api/ecs/client/ecs_client.go index 6e57ba6b27e..239975db500 100644 --- a/ecs-agent/api/ecs/client/ecs_client.go +++ b/ecs-agent/api/ecs/client/ecs_client.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "net/http" + "os" "strings" "time" @@ -231,6 +232,16 @@ func (client *ecsClient) registerContainerInstance(clusterRef string, containerI registerRequest.ClientToken = ®istrationToken resp, err := client.standardClient.RegisterContainerInstance(®isterRequest) if err != nil { + + // On error codes InvalidParameterException and Client exception, there is no need to retry as + // retries will not resolve the issue. Exit 5 to terminally stop agent process to avoid a retry storm + // due to agent restart. + if utils.IsAWSErrorCodeEqual(err, ecsmodel.ErrCodeInvalidParameterException) || + utils.IsAWSErrorCodeEqual(err, ecsmodel.ErrCodeAccessDeniedException) { + os.Exit(5) + } + // All other exceptions is deferred to agent retry per existing retry strategy + // including throttlingExceptions. logger.Error("Unable to register as a container instance with ECS", logger.Fields{ field.Error: err, }) diff --git a/ecs-agent/api/ecs/client/ecs_client_test.go b/ecs-agent/api/ecs/client/ecs_client_test.go index 842e599c5f3..1ff135bdeb8 100644 --- a/ecs-agent/api/ecs/client/ecs_client_test.go +++ b/ecs-agent/api/ecs/client/ecs_client_test.go @@ -462,6 +462,129 @@ func TestRegisterContainerInstance(t *testing.T) { } } +func TestExceptionsRegisterContainerInstance(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockEC2Metadata := mock_ec2.NewMockEC2MetadataClient(ctrl) + additionalAttributes := map[string]string{ + "my_custom_attribute": "Custom_Value1", + "my_other_custom_attribute": "Custom_Value2", + "attribute_name_with_no_value": "", + } + cfgAccessorOverrideFunc := func(cfgAccessor *mock_config.MockAgentConfigAccessor) { + cfgAccessor.EXPECT().InstanceAttributes().Return(additionalAttributes).AnyTimes() + } + tester := setup(t, ctrl, mockEC2Metadata, cfgAccessorOverrideFunc) + + fakeCapabilities := []string{"capability1", "capability2"} + expectedAttributes := map[string]string{ + "ecs.os-type": tester.mockCfgAccessor.OSType(), + "ecs.os-family": tester.mockCfgAccessor.OSFamily(), + "ecs.availability-zone": availabilityZone, + "ecs.outpost-arn": outpostARN, + } + for i := range fakeCapabilities { + expectedAttributes[fakeCapabilities[i]] = "" + } + capabilities := buildAttributeList(fakeCapabilities, nil) + + commonExpectations := func() { + mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentResource). + Return("instanceIdentityDocument", nil) + mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource). + Return("signature", nil) + } + + validateRequest := func(req *ecsmodel.RegisterContainerInstanceInput) { + assert.Equal(t, "arn:test", *req.ContainerInstanceArn, "Wrong container instance ARN") + assert.Equal(t, configuredCluster, *req.Cluster, "Wrong cluster") + assert.Equal(t, registrationToken, *req.ClientToken, "Wrong client token") + assert.Equal(t, iid, *req.InstanceIdentityDocument, "Wrong IID") + assert.Equal(t, iidSignature, *req.InstanceIdentityDocumentSignature, "Wrong IID sig") + assert.Equal(t, 4, len(req.TotalResources), "Wrong length of TotalResources") + resource, ok := findResource(req.TotalResources, "PORTS_UDP") + assert.True(t, ok, `Could not find resource "PORTS_UDP"`) + assert.Equal(t, "STRINGSET", *resource.Type, `Wrong type for resource "PORTS_UDP"`) + assert.Equal(t, 5, len(req.Attributes), "Wrong number of Attributes") + reqAttributes := func() map[string]string { + rv := make(map[string]string, len(req.Attributes)) + for i := range req.Attributes { + rv[aws.StringValue(req.Attributes[i].Name)] = aws.StringValue(req.Attributes[i].Value) + } + return rv + }() + for k, v := range reqAttributes { + assert.Contains(t, expectedAttributes, k) + assert.Equal(t, expectedAttributes[k], v) + } + assert.Equal(t, len(containerInstanceTags), len(req.Tags), "Wrong number of tags") + reqTags := extractTagsMapFromRegisterContainerInstanceInput(req) + for k, v := range reqTags { + assert.Contains(t, containerInstanceTagsMap, k) + assert.Equal(t, containerInstanceTagsMap[k], v) + } + } + + t.Run("Normal case", func(t *testing.T) { + commonExpectations() + tester.mockStandardClient.EXPECT().RegisterContainerInstance(gomock.Any()). + Do(validateRequest). + Return(&ecsmodel.RegisterContainerInstanceOutput{ + ContainerInstance: &ecsmodel.ContainerInstance{ + ContainerInstanceArn: aws.String(containerInstanceARN), + Attributes: buildAttributeList(fakeCapabilities, expectedAttributes), + }, + }, nil) + + arn, availabilityzone, err := tester.client.RegisterContainerInstance("arn:test", capabilities, + containerInstanceTags, registrationToken, nil, outpostARN) + + assert.NoError(t, err) + assert.Equal(t, containerInstanceARN, arn) + assert.Equal(t, availabilityZone, availabilityzone, "availabilityZone is incorrect") + }) + + t.Run("InvalidParameterException", func(t *testing.T) { + commonExpectations() + tester.mockStandardClient.EXPECT().RegisterContainerInstance(gomock.Any()). + Do(validateRequest). + Return(nil, awserr.New(ecsmodel.ErrCodeInvalidParameterException, "Invalid parameter", nil)) + + defer func() { + if r := recover(); r != nil { + assert.Equal(t, 5, r) + } else { + t.Errorf("The code did not panic") + } + }() + + _, _, err := tester.client.RegisterContainerInstance("arn:test", capabilities, + containerInstanceTags, registrationToken, nil, outpostARN) + + t.Errorf("Expected panic, got error: %v", err) + }) + + t.Run("AccessDeniedException", func(t *testing.T) { + commonExpectations() + tester.mockStandardClient.EXPECT().RegisterContainerInstance(gomock.Any()). + Do(validateRequest). + Return(nil, awserr.New(ecsmodel.ErrCodeAccessDeniedException, "Access denied", nil)) + + defer func() { + if r := recover(); r != nil { + assert.Equal(t, 5, r) + } else { + t.Errorf("The code did not panic") + } + }() + + _, _, err := tester.client.RegisterContainerInstance("arn:test", capabilities, + containerInstanceTags, registrationToken, nil, outpostARN) + + t.Errorf("Expected panic, got error: %v", err) + }) +} + func TestReRegisterContainerInstance(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() diff --git a/ecs-agent/utils/utils.go b/ecs-agent/utils/utils.go index ede4cda84c5..a11f8d2f5d4 100644 --- a/ecs-agent/utils/utils.go +++ b/ecs-agent/utils/utils.go @@ -17,6 +17,7 @@ import ( "strconv" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" "golang.org/x/exp/constraints" ) @@ -68,3 +69,11 @@ func MaxNum[T constraints.Integer | constraints.Float](a, b T) T { } return b } + +// IsAWSErrorCodeEqual returns true if the err implements Error +// interface of awserr and it has the same error code as +// the passed in error code. +func IsAWSErrorCodeEqual(err error, code string) bool { + awsErr, ok := err.(awserr.Error) + return ok && awsErr.Code() == code +} diff --git a/ecs-agent/utils/utils_test.go b/ecs-agent/utils/utils_test.go index 817c3be9290..f846fdf6661 100644 --- a/ecs-agent/utils/utils_test.go +++ b/ecs-agent/utils/utils_test.go @@ -13,10 +13,13 @@ package utils import ( + "errors" "strconv" "testing" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -129,3 +132,33 @@ func testMaxNumFloat(t *testing.T) { require.Equal(t, largerVal, MaxNum(largerVal, smallerVal)) require.Equal(t, largerVal, MaxNum(largerVal, largerVal)) } + +func TestIsAWSErrorCodeEqual(t *testing.T) { + testcases := []struct { + name string + err error + res bool + }{ + { + name: "Happy Path", + err: awserr.New(ecs.ErrCodeInvalidParameterException, "errMsg", errors.New("err")), + res: true, + }, + { + name: "Wrong Error Code", + err: awserr.New("errCode", "errMsg", errors.New("err")), + res: false, + }, + { + name: "Wrong Error Type", + err: errors.New("err"), + res: false, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.res, IsAWSErrorCodeEqual(tc.err, ecs.ErrCodeInvalidParameterException)) + }) + } +}