From ef0ea0cc8d0eb606d18e186fdc68b3c8c10048e8 Mon Sep 17 00:00:00 2001 From: Erik Weber Date: Thu, 19 Sep 2024 13:36:37 +0200 Subject: [PATCH] Add --scan-all clusters option to ECS snapshot command --- cmd/kosli/root.go | 3 ++- cmd/kosli/snapshotECS.go | 47 +++++++++++++++++++++++++++-------- cmd/kosli/snapshotECS_test.go | 11 ++++++-- internal/aws/aws.go | 18 ++++++++++++++ 4 files changed, 65 insertions(+), 14 deletions(-) diff --git a/cmd/kosli/root.go b/cmd/kosli/root.go index 1fb1e2ee6..06090c198 100644 --- a/cmd/kosli/root.go +++ b/cmd/kosli/root.go @@ -157,8 +157,9 @@ The service principal needs to have the following permissions: resultsDirFlag = "[defaulted] The path to a directory with JUnit test results. By default, the directory will be uploaded to Kosli's evidence vault." snykJsonResultsFileFlag = "The path to Snyk SARIF or JSON scan results file from 'snyk test' and 'snyk container test'. By default, the Snyk results will be uploaded to Kosli's evidence vault." snykSarifResultsFileFlag = "The path to Snyk scan SARIF results file from 'snyk test' and 'snyk container test'. By default, the Snyk results will be uploaded to Kosli's evidence vault." - ecsClusterFlag = "The name of the ECS cluster." + ecsClusterFlag = "[exclusive] The name of the ECS cluster." ecsServiceFlag = "[optional] The name of the ECS service." + ecsScanAllFlag = "[exclusive] Scan all ECS clusters." kubeconfigFlag = "[defaulted] The kubeconfig path for the target cluster." namespaceFlag = "[conditional] The comma separated list of namespaces regex patterns to report artifacts info from. Can't be used together with --exclude-namespace." excludeNamespaceFlag = "[conditional] The comma separated list of namespaces regex patterns NOT to report artifacts info from. Can't be used together with --namespace." diff --git a/cmd/kosli/snapshotECS.go b/cmd/kosli/snapshotECS.go index 4c3716b0b..1330f8711 100644 --- a/cmd/kosli/snapshotECS.go +++ b/cmd/kosli/snapshotECS.go @@ -51,6 +51,7 @@ type snapshotECSOptions struct { cluster string serviceName string awsStaticCreds *aws.AWSStaticCreds + scanAll bool } func newSnapshotECSCmd(out io.Writer) *cobra.Command { @@ -67,6 +68,11 @@ func newSnapshotECSCmd(out io.Writer) *cobra.Command { if err != nil { return ErrorBeforePrintingUsage(cmd, err.Error()) } + + err = MuXRequiredFlags(cmd, []string{"cluster", "scan-all"}, true) + if err != nil { + return err + } return nil }, RunE: func(cmd *cobra.Command, args []string) error { @@ -76,28 +82,47 @@ func newSnapshotECSCmd(out io.Writer) *cobra.Command { cmd.Flags().StringVarP(&o.cluster, "cluster", "C", "", ecsClusterFlag) cmd.Flags().StringVarP(&o.serviceName, "service-name", "s", "", ecsServiceFlag) + cmd.Flags().BoolVarP(&o.scanAll, "scan-all", "A", false, ecsScanAllFlag) addAWSAuthFlags(cmd, o.awsStaticCreds) addDryRunFlag(cmd) - err := RequireFlags(cmd, []string{"cluster"}) - if err != nil { - logger.Error("failed to configure required flags: %v", err) - } - return cmd } func (o *snapshotECSOptions) run(args []string) error { envName := args[0] url := fmt.Sprintf("%s/api/v2/environments/%s/%s/report/ECS", global.Host, global.Org, envName) - - tasksData, err := o.awsStaticCreds.GetEcsTasksData(o.cluster, o.serviceName) - if err != nil { - return err + logger.Debug("ECS Snapshot parameters: scan-all: %t, cluster: %s", o.scanAll, o.cluster) + + tasksDataList := []*aws.EcsTaskData{} + + if o.scanAll { + clusters, err := o.awsStaticCreds.GetECSClusters() + if err != nil { + logger.Error("Failed to get ECS clusters: %v", err) + return err + } + logger.Debug("Attempting to find ECS clusters") + for _, cluster := range clusters { + logger.Debug("Found ECS cluster with ARN: %s", cluster) + tasksData, err := o.awsStaticCreds.GetEcsTasksData(cluster, o.serviceName) + if err != nil { + return err + } + // append to EcsTaskDataList + tasksDataList = append(tasksDataList, tasksData...) + + } + } else { + tasksData, err := o.awsStaticCreds.GetEcsTasksData(o.cluster, o.serviceName) + if err != nil { + return err + } + tasksDataList = append(tasksDataList, tasksData...) } payload := &aws.EcsEnvRequest{ - Artifacts: tasksData, + Artifacts: tasksDataList, } reqParams := &requests.RequestParams{ @@ -107,7 +132,7 @@ func (o *snapshotECSOptions) run(args []string) error { DryRun: global.DryRun, Password: global.ApiToken, } - _, err = kosliClient.Do(reqParams) + _, err := kosliClient.Do(reqParams) if err == nil && !global.DryRun { logger.Info("[%d] containers were reported to environment %s", len(payload.Artifacts), envName) } diff --git a/cmd/kosli/snapshotECS_test.go b/cmd/kosli/snapshotECS_test.go index 02ded4279..98734df69 100644 --- a/cmd/kosli/snapshotECS_test.go +++ b/cmd/kosli/snapshotECS_test.go @@ -37,9 +37,16 @@ func (suite *SnapshotECSTestSuite) TestSnapshotECSCmd() { tests := []cmdTestCase{ { wantError: true, - name: "snapshot ECS fails if --cluster is missing", + name: "snapshot ECS fails if --cluster and --scan-all is missing", cmd: fmt.Sprintf(`snapshot ecs %s %s`, suite.envName, suite.defaultKosliArguments), - golden: "Error: required flag(s) \"cluster\" not set\n", + golden: "Error: at least one of --cluster, --scan-all is required\n", + // golden: "Error: required flag(s) \"cluster\" not set\n", + }, + { + wantError: true, + name: "snapshot ECS fails if both --cluster and --scan-all is set", + cmd: fmt.Sprintf(`snapshot ecs %s %s --cluster sss --scan-all`, suite.envName, suite.defaultKosliArguments), + golden: "Error: only one of --cluster, --scan-all is allowed\n", }, { wantError: true, diff --git a/internal/aws/aws.go b/internal/aws/aws.go index 6f66cbfb7..8e60cd57e 100644 --- a/internal/aws/aws.go +++ b/internal/aws/aws.go @@ -125,6 +125,24 @@ func (staticCreds *AWSStaticCreds) NewECSClient() (*ecs.Client, error) { return ecs.NewFromConfig(cfg), nil } +func (staticCreds *AWSStaticCreds) GetECSClusters() ([]string, error) { + ecsClient, err := staticCreds.NewECSClient() + if err != nil { + return []string{}, err + } + clusters := []string{} + params := &ecs.ListClustersInput{} + paginator := ecs.NewListClustersPaginator(ecsClient, params) + for paginator.HasMorePages() { + output, err := paginator.NextPage(context.Background()) + if err != nil { + return clusters, err + } + clusters = append(clusters, output.ClusterArns...) + } + return clusters, nil +} + // getAllLambdaFuncs fetches all lambda functions recursively (50 at a time) and returns a list of FunctionConfiguration func getAllLambdaFuncs(client *lambda.Client, nextMarker *string, allFunctions *[]types.FunctionConfiguration) (*[]types.FunctionConfiguration, error) { params := &lambda.ListFunctionsInput{}