diff --git a/internal/commands/root_test.go b/internal/commands/root_test.go index 758756070..7006b0e06 100644 --- a/internal/commands/root_test.go +++ b/internal/commands/root_test.go @@ -206,6 +206,7 @@ func assertError(t *testing.T, err error, expectedMessage string) { func clearFlags() { mock.Flags = wrappers.FeatureFlagsResponseModel{} mock.Flag = wrappers.FeatureFlagResponseModel{} + mock.FFErr = nil wrappers.ClearCache() } diff --git a/internal/commands/scan.go b/internal/commands/scan.go index b6adc9b23..89381d910 100644 --- a/internal/commands/scan.go +++ b/internal/commands/scan.go @@ -1132,7 +1132,7 @@ func validateScanTypes(cmd *cobra.Command, jwtWrapper wrappers.JWTWrapper, featu var scanTypes []string var SCSScanTypes []string - containerEngineCLIEnabled, _ := featureFlagsWrapper.GetSpecificFlag(wrappers.ContainerEngineCLIEnabled) + runContainerEngineCLI := isContainersEngineEnabled(featureFlagsWrapper) allowedEngines, err := jwtWrapper.GetAllowedEngines(featureFlagsWrapper) if err != nil { err = errors.Errorf("Error validating scan types: %v", err) @@ -1149,7 +1149,7 @@ func validateScanTypes(cmd *cobra.Command, jwtWrapper wrappers.JWTWrapper, featu scanTypes = strings.Split(userScanTypes, ",") for _, scanType := range scanTypes { - if !allowedEngines[scanType] || (scanType == commonParams.ContainersType && !(containerEngineCLIEnabled.Status)) { + if !allowedEngines[scanType] || (scanType == commonParams.ContainersType && !(runContainerEngineCLI)) { keys := reflect.ValueOf(allowedEngines).MapKeys() err = errors.Errorf(engineNotAllowed, scanType, scanType, keys) return err @@ -1165,7 +1165,7 @@ func validateScanTypes(cmd *cobra.Command, jwtWrapper wrappers.JWTWrapper, featu } else { for k := range allowedEngines { - if k == commonParams.ContainersType && !(containerEngineCLIEnabled.Status) { + if k == commonParams.ContainersType && !(runContainerEngineCLI) { continue } scanTypes = append(scanTypes, k) @@ -1178,6 +1178,16 @@ func validateScanTypes(cmd *cobra.Command, jwtWrapper wrappers.JWTWrapper, featu return nil } +func isContainersEngineEnabled(featureFlagsWrapper wrappers.FeatureFlagsWrapper) bool { + containerEngineCLIEnabled, err := featureFlagsWrapper.GetSpecificFlag(wrappers.ContainerEngineCLIEnabled) + if err != nil { + logger.PrintfIfVerbose("Failed to fetch CONTAINER_ENGINE_CLI_ENABLED FF, defaulting to `false`. Error: %s", err) + return false + } + + return containerEngineCLIEnabled.Status +} + func scanTypeEnabled(scanType string) bool { scanTypes := strings.Split(actualScanTypes, ",") for _, a := range scanTypes { diff --git a/internal/commands/scan_test.go b/internal/commands/scan_test.go index 86e1f9eff..dc7dbeb37 100644 --- a/internal/commands/scan_test.go +++ b/internal/commands/scan_test.go @@ -1958,3 +1958,22 @@ func TestValidateScanTypes(t *testing.T) { }) } } + +func TestIsContainersEngineEnabled_FlagEnabled(t *testing.T) { + clearFlags() + mock.Flag = wrappers.FeatureFlagResponseModel{Name: wrappers.ContainerEngineCLIEnabled, Status: true} + mock.FFErr = nil + + result := isContainersEngineEnabled(mock.FeatureFlagsMockWrapper{}) + assert.Assert(t, result, "expected result to be true") +} + +func TestIsContainersEngineEnabled_FlagRetrievalFails(t *testing.T) { + clearFlags() + mock.Flag = wrappers.FeatureFlagResponseModel{Name: wrappers.ContainerEngineCLIEnabled, Status: false} + mock.FFErr = errors.New("something went wrong while fetching ff") + + result := isContainersEngineEnabled(mock.FeatureFlagsMockWrapper{}) + + assert.Assert(t, !result, "expected result to be false") +} diff --git a/internal/wrappers/mock/feature-flags-mock.go b/internal/wrappers/mock/feature-flags-mock.go index c4932916d..95f62fccb 100644 --- a/internal/wrappers/mock/feature-flags-mock.go +++ b/internal/wrappers/mock/feature-flags-mock.go @@ -8,6 +8,7 @@ import ( var Flags wrappers.FeatureFlagsResponseModel var Flag wrappers.FeatureFlagResponseModel +var FFErr error type FeatureFlagsMockWrapper struct { } @@ -22,5 +23,8 @@ func (f FeatureFlagsMockWrapper) GetAll() (*wrappers.FeatureFlagsResponseModel, func (f FeatureFlagsMockWrapper) GetSpecificFlag(specificFlag string) (*wrappers.FeatureFlagResponseModel, error) { fmt.Println("Called GetSpecificFlag in FeatureFlagsMockWrapper with flag:", specificFlag) + if FFErr != nil { + return nil, FFErr + } return &Flag, nil }