diff --git a/internal/pkg/object/command/ecs/ecs.go b/internal/pkg/object/command/ecs/ecs.go index d53e202..eec206a 100644 --- a/internal/pkg/object/command/ecs/ecs.go +++ b/internal/pkg/object/command/ecs/ecs.go @@ -20,6 +20,7 @@ import ( "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" + "github.com/pkg/errors" ) // ECS command context structure @@ -676,6 +677,7 @@ func (execCtx *executionContext) retrieveLogs(ctx context.Context) error { } func (e *commandContext) Cleanup(ctx context.Context, jobID string, c *cluster.Cluster) error { + cleanupMethod.CountRequest() // Resolve cluster context to get cluster name clusterContext := &clusterContext{} @@ -691,8 +693,10 @@ func (e *commandContext) Cleanup(ctx context.Context, jobID string, c *cluster.C ecsClient := ecs.NewFromConfig(cfg) // List all tasks started by this job - var allTaskARNs []string - for taskNum := 0; taskNum < e.TaskCount; taskNum++ { + maxTaskCount := clusterContext.MaxTaskCount + + taskARNs := make([]string, 0) + for taskNum := 0; taskNum < maxTaskCount; taskNum++ { startedByValue := fmt.Sprintf("%s%s-%d", startedByPrefix, jobID, taskNum) listTasksOutput, err := ecsClient.ListTasks(ctx, &ecs.ListTasksInput{ @@ -703,44 +707,32 @@ func (e *commandContext) Cleanup(ctx context.Context, jobID string, c *cluster.C cleanupMethod.CountError("list_tasks") return err } - allTaskARNs = append(allTaskARNs, listTasksOutput.TaskArns...) + + taskARNs = append(taskARNs, listTasksOutput.TaskArns...) + + time.Sleep(100 * time.Millisecond) // prevent API throttling } - if len(allTaskARNs) == 0 { + if len(taskARNs) == 0 { // No tasks found, nothing to clean up cleanupMethod.CountSuccess("no_tasks_found") return nil } - // Bulk describe all tasks to check their LastStatus - describeOutput, err := ecsClient.DescribeTasks(ctx, &ecs.DescribeTasksInput{ - Cluster: aws.String(clusterContext.ClusterName), - Tasks: allTaskARNs, - }) - if err != nil { - cleanupMethod.CountError("describe_tasks") - return err - } - - // Stop all tasks where LastStatus != STOPPED or SUCCEEDED - for _, task := range describeOutput.Tasks { - // Skip tasks that are already stopped - if aws.ToString(task.LastStatus) == "STOPPED" || aws.ToString(task.LastStatus) == "SUCCEEDED" { - continue - } - + // Stop all tasks we found. StopTask is safe to call even if the task is already stopping/stopped. + for _, taskARN := range taskARNs { stopTaskInput := &ecs.StopTaskInput{ Cluster: aws.String(clusterContext.ClusterName), - Task: task.TaskArn, + Task: aws.String(taskARN), Reason: aws.String(errJobTerminated), } - _, err := ecsClient.StopTask(ctx, stopTaskInput) - if err != nil { + if _, err := ecsClient.StopTask(ctx, stopTaskInput); err != nil { // Log error but continue stopping other tasks - cleanupMethod.LogAndCountError(err, fmt.Sprintf("failed to stop task %s", aws.ToString(task.TaskArn))) - continue + err = errors.Wrapf(err, "failed to stop task %s", taskARN) + cleanupMethod.LogAndCountError(err, "stop_task") } - cleanupMethod.CountSuccess("stop_task") + + time.Sleep(100 * time.Millisecond) // prevent API throttling } cleanupMethod.CountSuccess() return nil