From 6fe572f382bbd3a45b302fafade20629cec5f552 Mon Sep 17 00:00:00 2001
From: BinBin He <binbinhe@amazon.com>
Date: Fri, 27 Dec 2024 17:52:17 +0000
Subject: [PATCH] Terminally exit on unrecoverable exceptions for RegisterContainerInstance (RCI).

---
 agent/app/agent.go      |  43 ++++++---
 agent/app/agent_test.go | 187 +++++++++++++++++++++++++++++++++++++---
 agent/app/errors.go     |  13 +--
 3 files changed, 217 insertions(+), 26 deletions(-)

diff --git a/agent/app/agent.go b/agent/app/agent.go
index de61e74fa34..c52a2a06b55 100644
--- a/agent/app/agent.go
+++ b/agent/app/agent.go
@@ -468,10 +468,15 @@ func (agent *ecsAgent) doStart(containerChangeEventStream *eventstream.EventStre
 	// Register the container instance
 	err = agent.registerContainerInstance(client, vpcSubnetAttributes)
 	if err != nil {
-		if isTransient(err) {
-			return exitcodes.ExitError
+		if isTerminal(err) {
+			// On unrecoverable error codes, agent should terminally exit.
+			logger.Critical("Agent will terminally exit, unable to register container instance:", logger.Fields{
+				field.Error: err,
+			})
+			return exitcodes.ExitTerminal
 		}
-		return exitcodes.ExitTerminal
+		// Other errors are considered recoverable and will be retried.
+		return exitcodes.ExitError
 	}
 
 	// Load Managed Daemon images asynchronously
@@ -855,13 +860,19 @@ func (agent *ecsAgent) registerContainerInstance(
 			field.Error: err,
 		})
 		if retriable, ok := err.(apierrors.Retriable); ok && !retriable.Retry() {
-			return err
+			return terminalError{err}
 		}
 		if utils.IsAWSErrorCodeEqual(err, ecsmodel.ErrCodeInvalidParameterException) {
 			logger.Critical("Instance registration attempt with an invalid parameter", logger.Fields{
 				field.Error: err,
 			})
-			return err
+			return terminalError{err}
+		}
+		if utils.IsAWSErrorCodeEqual(err, ecsmodel.ErrCodeClientException) {
+			logger.Critical("Instance registration attempt with client performing invalid action", logger.Fields{
+				field.Error: err,
+			})
+			return terminalError{err}
 		}
 		if _, ok := err.(apierrors.AttributeError); ok {
 			attributeErrorMsg := ""
@@ -871,9 +882,9 @@ func (agent *ecsAgent) registerContainerInstance(
 			logger.Critical("Instance registration attempt with invalid attribute(s)", logger.Fields{
 				field.Error: attributeErrorMsg,
 			})
-			return err
+			return terminalError{err}
 		}
-		return transientError{err}
+		return err
 	}
 	logger.Info("Instance registration completed successfully", logger.Fields{
 		"instanceArn": containerInstanceArn,
@@ -903,7 +914,19 @@ func (agent *ecsAgent) reregisterContainerInstance(client ecs.ECSClient, capabil
 	})
 	if apierrors.IsInstanceTypeChangedError(err) {
 		seelog.Criticalf(instanceTypeMismatchErrorFormat, err)
-		return err
+		return terminalError{err}
+	}
+	if utils.IsAWSErrorCodeEqual(err, ecsmodel.ErrCodeInvalidParameterException) {
+		logger.Critical("Instance re-registration attempt with an invalid parameter", logger.Fields{
+			field.Error: err,
+		})
+		return terminalError{err}
+	}
+	if utils.IsAWSErrorCodeEqual(err, ecsmodel.ErrCodeClientException) {
+		logger.Critical("Instance re-registration attempt with client performing invalid action", logger.Fields{
+			field.Error: err,
+		})
+		return terminalError{err}
 	}
 	if _, ok := err.(apierrors.AttributeError); ok {
 		attributeErrorMsg := ""
@@ -913,9 +936,9 @@ func (agent *ecsAgent) reregisterContainerInstance(client ecs.ECSClient, capabil
 		logger.Critical("Instance re-registration attempt with invalid attribute(s)", logger.Fields{
 			field.Error: attributeErrorMsg,
 		})
-		return err
+		return terminalError{err}
 	}
-	return transientError{err}
+	return err
 }
 
 // startAsyncRoutines starts all background methods
diff --git a/agent/app/agent_test.go b/agent/app/agent_test.go
index 9dc8bbafa36..911fb11bfd7 100644
--- a/agent/app/agent_test.go
+++ b/agent/app/agent_test.go
@@ -760,6 +760,7 @@ func TestNewTaskEngineRestoreFromCheckpointNewStateManagerError(t *testing.T) {
 
 	ec2MetadataClient := mock_ec2.NewMockEC2MetadataClient(ctrl)
 	mockPauseLoader := mock_loader.NewMockLoader(ctrl)
+	newError := errors.New("error")
 	gomock.InOrder(
 		saveableOptionFactory.EXPECT().AddSaveable("TaskEngine", gomock.Any()).Return(nil),
 		saveableOptionFactory.EXPECT().AddSaveable("ContainerInstanceArn", gomock.Any()).Return(nil),
@@ -770,7 +771,7 @@ func TestNewTaskEngineRestoreFromCheckpointNewStateManagerError(t *testing.T) {
 
 		stateManagerFactory.EXPECT().NewStateManager(gomock.Any(), gomock.Any(),
 			gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(
-			nil, errors.New("error")),
+			nil, newError),
 	)
 
 	dataClient := newTestDataClient(t)
@@ -795,7 +796,6 @@ func TestNewTaskEngineRestoreFromCheckpointNewStateManagerError(t *testing.T) {
 	_, _, err := agent.newTaskEngine(eventstream.NewEventStream("events", ctx),
 		credentialsManager, dockerstate.NewTaskEngineState(), imageManager, hostResources, execCmdMgr, serviceConnectManager, daemonManagers)
 	assert.Error(t, err)
-	assert.False(t, isTransient(err))
 }
 
 func TestNewTaskEngineRestoreFromCheckpointStateLoadError(t *testing.T) {
@@ -808,6 +808,7 @@ func TestNewTaskEngineRestoreFromCheckpointStateLoadError(t *testing.T) {
 	cfg.Checkpoint = config.BooleanDefaultFalse{Value: config.ExplicitlyEnabled}
 	ec2MetadataClient := mock_ec2.NewMockEC2MetadataClient(ctrl)
 	mockPauseLoader := mock_loader.NewMockLoader(ctrl)
+	newError := errors.New("error")
 
 	gomock.InOrder(
 		saveableOptionFactory.EXPECT().AddSaveable("TaskEngine", gomock.Any()).Return(nil),
@@ -819,7 +820,7 @@ func TestNewTaskEngineRestoreFromCheckpointStateLoadError(t *testing.T) {
 		stateManagerFactory.EXPECT().NewStateManager(gomock.Any(),
 			gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(),
 		).Return(stateManager, nil),
-		stateManager.EXPECT().Load().Return(errors.New("error")),
+		stateManager.EXPECT().Load().Return(newError),
 	)
 
 	dataClient := newTestDataClient(t)
@@ -844,7 +845,6 @@ func TestNewTaskEngineRestoreFromCheckpointStateLoadError(t *testing.T) {
 	_, _, err := agent.newTaskEngine(eventstream.NewEventStream("events", ctx),
 		credentialsManager, dockerstate.NewTaskEngineState(), imageManager, hostResources, execCmdMgr, serviceConnectManager, daemonManagers)
 	assert.Error(t, err)
-	assert.False(t, isTransient(err))
 }
 
 func TestNewTaskEngineRestoreFromCheckpoint(t *testing.T) {
@@ -1086,7 +1086,7 @@ func TestReregisterContainerInstanceInstanceTypeChanged(t *testing.T) {
 
 	err := agent.registerContainerInstance(client, nil)
 	assert.Error(t, err)
-	assert.False(t, isTransient(err))
+	assert.True(t, isTerminal(err))
 }
 
 func TestReregisterContainerInstanceAttributeError(t *testing.T) {
@@ -1145,7 +1145,7 @@ func TestReregisterContainerInstanceAttributeError(t *testing.T) {
 
 	err := agent.registerContainerInstance(client, nil)
 	assert.Error(t, err)
-	assert.False(t, isTransient(err))
+	assert.True(t, isTerminal(err))
 }
 
 func TestReregisterContainerInstanceNonTerminalError(t *testing.T) {
@@ -1204,7 +1204,66 @@ func TestReregisterContainerInstanceNonTerminalError(t *testing.T) {
 
 	err := agent.registerContainerInstance(client, nil)
 	assert.Error(t, err)
-	assert.True(t, isTransient(err))
+	assert.False(t, isTerminal(err))
+}
+
+func TestReregisterContainerInstanceTerminalError(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mockDockerClient := mock_dockerapi.NewMockDockerClient(ctrl)
+	client := mock_ecs.NewMockECSClient(ctrl)
+	mockCredentialsProvider := app_mocks.NewMockCredentialsProvider(ctrl)
+	mockMobyPlugins := mock_mobypkgwrapper.NewMockPlugins(ctrl)
+	mockEC2Metadata := mock_ec2.NewMockEC2MetadataClient(ctrl)
+	mockPauseLoader := mock_loader.NewMockLoader(ctrl)
+
+	mockPauseLoader.EXPECT().IsLoaded(gomock.Any()).Return(false, nil).AnyTimes()
+	mockPauseLoader.EXPECT().LoadImage(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
+	mockServiceConnectManager := mock_serviceconnect.NewMockManager(ctrl)
+	mockServiceConnectManager.EXPECT().IsLoaded(gomock.Any()).Return(true, nil).AnyTimes()
+	mockServiceConnectManager.EXPECT().GetLoadedAppnetVersion().AnyTimes()
+	mockServiceConnectManager.EXPECT().GetCapabilitiesForAppnetInterfaceVersion("").AnyTimes()
+	mockServiceConnectManager.EXPECT().SetECSClient(gomock.Any(), gomock.Any()).AnyTimes()
+
+	mockDaemonManager := mock_daemonmanager.NewMockDaemonManager(ctrl)
+	mockDaemonManagers := map[string]dm.DaemonManager{md.EbsCsiDriver: mockDaemonManager}
+	mockDaemonManager.EXPECT().IsLoaded(gomock.Any()).Return(true, nil).AnyTimes()
+	mockDaemonManager.EXPECT().LoadImage(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
+
+	gomock.InOrder(
+		mockCredentialsProvider.EXPECT().Retrieve(gomock.Any()).Return(awsv2.Credentials{}, nil),
+		mockDockerClient.EXPECT().SupportedVersions().Return(apiVersions),
+		mockMobyPlugins.EXPECT().Scan().AnyTimes().Return([]string{}, nil),
+		mockDockerClient.EXPECT().ListPluginsWithFilters(gomock.Any(), gomock.Any(),
+			gomock.Any(), gomock.Any()).AnyTimes().Return([]string{}, nil),
+		client.EXPECT().RegisterContainerInstance(containerInstanceARN, gomock.Any(), gomock.Any(), gomock.Any(),
+			gomock.Any(), gomock.Any()).Return("", "", awserr.New("ClientException", "", nil)),
+	)
+	mockEC2Metadata.EXPECT().OutpostARN().Return("", nil)
+
+	cfg := getTestConfig()
+	cfg.Cluster = clusterName
+	ctx, cancel := context.WithCancel(context.TODO())
+	// Cancel the context to cancel async routines
+	defer cancel()
+	agent := &ecsAgent{
+		ctx:                   ctx,
+		cfg:                   &cfg,
+		dockerClient:          mockDockerClient,
+		ec2MetadataClient:     mockEC2Metadata,
+		pauseLoader:           mockPauseLoader,
+		credentialsCache:      awsv2.NewCredentialsCache(mockCredentialsProvider),
+		mobyPlugins:           mockMobyPlugins,
+		serviceconnectManager: mockServiceConnectManager,
+		daemonManagers:        mockDaemonManagers,
+	}
+	agent.containerInstanceARN = containerInstanceARN
+	agent.availabilityZone = availabilityZone
+
+	err := agent.registerContainerInstance(client, nil)
+	assert.Error(t, err)
+	assert.True(t, isTerminal(err))
 }
 
 func TestRegisterContainerInstanceWhenContainerInstanceARNIsNotSetHappyPath(t *testing.T) {
@@ -1320,7 +1379,7 @@ func TestRegisterContainerInstanceWhenContainerInstanceARNIsNotSetCanRetryError(
 
 	err := agent.registerContainerInstance(client, nil)
 	assert.Error(t, err)
-	assert.True(t, isTransient(err))
+	assert.False(t, isTerminal(err))
 }
 
 func TestRegisterContainerInstanceWhenContainerInstanceARNIsNotSetCannotRetryError(t *testing.T) {
@@ -1378,7 +1437,7 @@ func TestRegisterContainerInstanceWhenContainerInstanceARNIsNotSetCannotRetryErr
 
 	err := agent.registerContainerInstance(client, nil)
 	assert.Error(t, err)
-	assert.False(t, isTransient(err))
+	assert.True(t, isTerminal(err))
 }
 
 func TestRegisterContainerInstanceWhenContainerInstanceARNIsNotSetAttributeError(t *testing.T) {
@@ -1435,7 +1494,7 @@ func TestRegisterContainerInstanceWhenContainerInstanceARNIsNotSetAttributeError
 
 	err := agent.registerContainerInstance(client, nil)
 	assert.Error(t, err)
-	assert.False(t, isTransient(err))
+	assert.True(t, isTerminal(err))
 }
 
 func TestRegisterContainerInstanceInvalidParameterTerminalError(t *testing.T) {
@@ -1499,6 +1558,114 @@ func TestRegisterContainerInstanceInvalidParameterTerminalError(t *testing.T) {
 		credentialsManager, state, imageManager, client, execCmdMgr)
 	assert.Equal(t, exitcodes.ExitTerminal, exitCode)
 }
+
+func TestRegisterContainerInstanceExceptionErrors(t *testing.T) {
+	testCases := []struct {
+		name     string
+		regError error
+		exitCode int
+	}{
+		{
+			name:     "InvalidParameterException",
+			regError: awserr.New("InvalidParameterException", "", nil),
+			exitCode: exitcodes.ExitTerminal,
+		},
+		{
+			name:     "ClientException",
+			regError: awserr.New("ClientException", "", nil),
+			exitCode: exitcodes.ExitTerminal,
+		},
+		{
+			name:     "ThrottlingException",
+			regError: awserr.New("ThrottlingException", "", nil),
+			exitCode: exitcodes.ExitError,
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			ctrl, credentialsManager, state, imageManager, client,
+				dockerClient, _, _, execCmdMgr, _ := setup(t)
+			defer ctrl.Finish()
+
+			mockCredentialsProvider := app_mocks.NewMockCredentialsProvider(ctrl)
+			mockMobyPlugins := mock_mobypkgwrapper.NewMockPlugins(ctrl)
+			mockEC2Metadata := mock_ec2.NewMockEC2MetadataClient(ctrl)
+			mockPauseLoader := mock_loader.NewMockLoader(ctrl)
+			mockServiceConnectManager := mock_serviceconnect.NewMockManager(ctrl)
+			mockDaemonManager := mock_daemonmanager.NewMockDaemonManager(ctrl)
+			mockDaemonManagers := map[string]dm.DaemonManager{md.EbsCsiDriver: mockDaemonManager}
+
+			mockPauseLoader.EXPECT().IsLoaded(gomock.Any()).Return(false, nil).AnyTimes()
+			mockPauseLoader.EXPECT().LoadImage(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
+
+			mockEC2Metadata.EXPECT().PrimaryENIMAC().Return("mac", nil)
+			mockEC2Metadata.EXPECT().VPCID("mac").Return("vpc-id", nil)
+			mockEC2Metadata.EXPECT().SubnetID("mac").Return("subnet-id", nil)
+			mockEC2Metadata.EXPECT().OutpostARN().Return("", nil)
+
+			mockServiceConnectManager.EXPECT().IsLoaded(gomock.Any()).Return(true, nil).AnyTimes()
+			mockServiceConnectManager.EXPECT().GetLoadedAppnetVersion().AnyTimes()
+			mockServiceConnectManager.EXPECT().GetCapabilitiesForAppnetInterfaceVersion("").AnyTimes()
+			mockServiceConnectManager.EXPECT().SetECSClient(gomock.Any(), gomock.Any()).AnyTimes()
+
+			mockDaemonManager.EXPECT().IsLoaded(gomock.Any()).Return(true, nil).AnyTimes()
+			mockDaemonManager.EXPECT().LoadImage(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
+
+			dockerClient.EXPECT().SupportedVersions().Return(apiVersions).AnyTimes()
+
+			gomock.InOrder(
+				client.EXPECT().GetHostResources().Return(testHostResource, nil),
+				mockCredentialsProvider.EXPECT().Retrieve(gomock.Any()).Return(awsv2.Credentials{}, nil),
+				mockMobyPlugins.EXPECT().Scan().AnyTimes().Return([]string{}, nil),
+				dockerClient.EXPECT().ListPluginsWithFilters(gomock.Any(), gomock.Any(), gomock.Any(),
+					gomock.Any()).AnyTimes().Return([]string{}, nil),
+
+				client.EXPECT().
+					RegisterContainerInstance(
+						gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(),
+						gomock.Any(), gomock.Any(),
+					).
+					Return("", "", tc.regError),
+			)
+
+			cfg := getTestConfig()
+			ctx, cancel := context.WithCancel(context.TODO())
+			defer cancel()
+
+			agent := &ecsAgent{
+				ctx:               ctx,
+				ec2MetadataClient: mockEC2Metadata,
+				cfg:               &cfg,
+				pauseLoader:       mockPauseLoader,
+				credentialsCache:  awsv2.NewCredentialsCache(mockCredentialsProvider),
+				dockerClient:      dockerClient,
+				mobyPlugins:       mockMobyPlugins,
+				terminationHandler: func(
+					taskEngineState dockerstate.TaskEngineState,
+					dataClient data.Client,
+					taskEngine engine.TaskEngine,
+					cancel context.CancelFunc,
+				) {
+				},
+				serviceconnectManager: mockServiceConnectManager,
+				daemonManagers:        mockDaemonManagers,
+			}
+
+			exitCode := agent.doStart(
+				eventstream.NewEventStream("events", ctx),
+				credentialsManager,
+				state,
+				imageManager,
+				client,
+				execCmdMgr,
+			)
+
+			assert.Equal(t, tc.exitCode, exitCode)
+		})
+	}
+}
+
 func TestMergeTags(t *testing.T) {
 	ec2Key := "ec2Key"
 	ec2Value := "ec2Value"
diff --git a/agent/app/errors.go b/agent/app/errors.go
index 06b02349b84..66545904bc8 100644
--- a/agent/app/errors.go
+++ b/agent/app/errors.go
@@ -13,15 +13,16 @@
 
 package app
 
-// transientError represents a transient error when executing the ECS Agent
-type transientError struct {
+type terminalError struct {
 	error
 }
 
-// isTransient returns true if the error is transient
-func isTransient(err error) bool {
-	_, ok := err.(transientError)
-	return ok
+// isTerminal reports whether err is a terminalError, i.e. an unrecoverable condition
+// for which the agent should exit terminally rather than retry.
+func isTerminal(err error) bool {
+	// Direct type assertion: only an err that is itself a terminalError matches; error chains are not unwrapped.
+	_, terminal := err.(terminalError)
+	return terminal
 }
 
 // clusterMismatchError represents a mismatch in cluster name between the