Skip to content

Commit

Permalink
Handling specific exception codes on RCI call.
Browse files Browse the repository at this point in the history
  • Loading branch information
BinBin He committed Dec 23, 2024
1 parent 41d593c commit 88590d9
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 0 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions ecs-agent/api/ecs/client/ecs_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"errors"
"fmt"
"net/http"
"os"
"strings"
"time"

Expand Down Expand Up @@ -231,6 +232,16 @@ func (client *ecsClient) registerContainerInstance(clusterRef string, containerI
registerRequest.ClientToken = &registrationToken
resp, err := client.standardClient.RegisterContainerInstance(&registerRequest)
if err != nil {

// On error codes InvalidParameterException and Client exception, there is no need to retry as
// retries will not resolve the issue. Exit 5 to terminally stop agent process to avoid a retry storm
// due to agent restart.
if utils.IsAWSErrorCodeEqual(err, ecsmodel.ErrCodeInvalidParameterException) ||
utils.IsAWSErrorCodeEqual(err, ecsmodel.ErrCodeAccessDeniedException) {
os.Exit(5)
}
// All other exceptions is deferred to agent retry per existing retry strategy
// including throttlingExceptions.
logger.Error("Unable to register as a container instance with ECS", logger.Fields{
field.Error: err,
})
Expand Down
123 changes: 123 additions & 0 deletions ecs-agent/api/ecs/client/ecs_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,129 @@ func TestRegisterContainerInstance(t *testing.T) {
}
}

func TestExceptionsRegisterContainerInstance(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockEC2Metadata := mock_ec2.NewMockEC2MetadataClient(ctrl)
additionalAttributes := map[string]string{
"my_custom_attribute": "Custom_Value1",
"my_other_custom_attribute": "Custom_Value2",
"attribute_name_with_no_value": "",
}
cfgAccessorOverrideFunc := func(cfgAccessor *mock_config.MockAgentConfigAccessor) {
cfgAccessor.EXPECT().InstanceAttributes().Return(additionalAttributes).AnyTimes()
}
tester := setup(t, ctrl, mockEC2Metadata, cfgAccessorOverrideFunc)

fakeCapabilities := []string{"capability1", "capability2"}
expectedAttributes := map[string]string{
"ecs.os-type": tester.mockCfgAccessor.OSType(),
"ecs.os-family": tester.mockCfgAccessor.OSFamily(),
"ecs.availability-zone": availabilityZone,
"ecs.outpost-arn": outpostARN,
}
for i := range fakeCapabilities {
expectedAttributes[fakeCapabilities[i]] = ""
}
capabilities := buildAttributeList(fakeCapabilities, nil)

commonExpectations := func() {
mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentResource).
Return("instanceIdentityDocument", nil)
mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource).
Return("signature", nil)
}

validateRequest := func(req *ecsmodel.RegisterContainerInstanceInput) {
assert.Equal(t, "arn:test", *req.ContainerInstanceArn, "Wrong container instance ARN")
assert.Equal(t, configuredCluster, *req.Cluster, "Wrong cluster")
assert.Equal(t, registrationToken, *req.ClientToken, "Wrong client token")
assert.Equal(t, iid, *req.InstanceIdentityDocument, "Wrong IID")
assert.Equal(t, iidSignature, *req.InstanceIdentityDocumentSignature, "Wrong IID sig")
assert.Equal(t, 4, len(req.TotalResources), "Wrong length of TotalResources")
resource, ok := findResource(req.TotalResources, "PORTS_UDP")
assert.True(t, ok, `Could not find resource "PORTS_UDP"`)
assert.Equal(t, "STRINGSET", *resource.Type, `Wrong type for resource "PORTS_UDP"`)
assert.Equal(t, 5, len(req.Attributes), "Wrong number of Attributes")
reqAttributes := func() map[string]string {
rv := make(map[string]string, len(req.Attributes))
for i := range req.Attributes {
rv[aws.StringValue(req.Attributes[i].Name)] = aws.StringValue(req.Attributes[i].Value)
}
return rv
}()
for k, v := range reqAttributes {
assert.Contains(t, expectedAttributes, k)
assert.Equal(t, expectedAttributes[k], v)
}
assert.Equal(t, len(containerInstanceTags), len(req.Tags), "Wrong number of tags")
reqTags := extractTagsMapFromRegisterContainerInstanceInput(req)
for k, v := range reqTags {
assert.Contains(t, containerInstanceTagsMap, k)
assert.Equal(t, containerInstanceTagsMap[k], v)
}
}

t.Run("Normal case", func(t *testing.T) {
commonExpectations()
tester.mockStandardClient.EXPECT().RegisterContainerInstance(gomock.Any()).
Do(validateRequest).
Return(&ecsmodel.RegisterContainerInstanceOutput{
ContainerInstance: &ecsmodel.ContainerInstance{
ContainerInstanceArn: aws.String(containerInstanceARN),
Attributes: buildAttributeList(fakeCapabilities, expectedAttributes),
},
}, nil)

arn, availabilityzone, err := tester.client.RegisterContainerInstance("arn:test", capabilities,
containerInstanceTags, registrationToken, nil, outpostARN)

assert.NoError(t, err)
assert.Equal(t, containerInstanceARN, arn)
assert.Equal(t, availabilityZone, availabilityzone, "availabilityZone is incorrect")
})

t.Run("InvalidParameterException", func(t *testing.T) {
commonExpectations()
tester.mockStandardClient.EXPECT().RegisterContainerInstance(gomock.Any()).
Do(validateRequest).
Return(nil, awserr.New(ecsmodel.ErrCodeInvalidParameterException, "Invalid parameter", nil))

defer func() {
if r := recover(); r != nil {
assert.Equal(t, 5, r)
} else {
t.Errorf("The code did not panic")
}
}()

_, _, err := tester.client.RegisterContainerInstance("arn:test", capabilities,
containerInstanceTags, registrationToken, nil, outpostARN)

t.Errorf("Expected panic, got error: %v", err)
})

t.Run("AccessDeniedException", func(t *testing.T) {
commonExpectations()
tester.mockStandardClient.EXPECT().RegisterContainerInstance(gomock.Any()).
Do(validateRequest).
Return(nil, awserr.New(ecsmodel.ErrCodeAccessDeniedException, "Access denied", nil))

defer func() {
if r := recover(); r != nil {
assert.Equal(t, 5, r)
} else {
t.Errorf("The code did not panic")
}
}()

_, _, err := tester.client.RegisterContainerInstance("arn:test", capabilities,
containerInstanceTags, registrationToken, nil, outpostARN)

t.Errorf("Expected panic, got error: %v", err)
})
}

func TestReRegisterContainerInstance(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
Expand Down
9 changes: 9 additions & 0 deletions ecs-agent/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"strconv"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"golang.org/x/exp/constraints"
)

Expand Down Expand Up @@ -68,3 +69,11 @@ func MaxNum[T constraints.Integer | constraints.Float](a, b T) T {
}
return b
}

// IsAWSErrorCodeEqual returns true if the err implements Error
// interface of awserr and it has the same error code as
// the passed in error code.
func IsAWSErrorCodeEqual(err error, code string) bool {
awsErr, ok := err.(awserr.Error)
return ok && awsErr.Code() == code
}
33 changes: 33 additions & 0 deletions ecs-agent/utils/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
package utils

import (
"errors"
"strconv"
"testing"

"github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -129,3 +132,33 @@ func testMaxNumFloat(t *testing.T) {
require.Equal(t, largerVal, MaxNum(largerVal, smallerVal))
require.Equal(t, largerVal, MaxNum(largerVal, largerVal))
}

func TestIsAWSErrorCodeEqual(t *testing.T) {
testcases := []struct {
name string
err error
res bool
}{
{
name: "Happy Path",
err: awserr.New(ecs.ErrCodeInvalidParameterException, "errMsg", errors.New("err")),
res: true,
},
{
name: "Wrong Error Code",
err: awserr.New("errCode", "errMsg", errors.New("err")),
res: false,
},
{
name: "Wrong Error Type",
err: errors.New("err"),
res: false,
},
}

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.res, IsAWSErrorCodeEqual(tc.err, ecs.ErrCodeInvalidParameterException))
})
}
}

0 comments on commit 88590d9

Please sign in to comment.