From 73a51ee08d959efe2a11f33a5d4b0058b25d5356 Mon Sep 17 00:00:00 2001 From: Rajdeep Singh Chauhan Date: Wed, 7 Jun 2023 09:29:06 -0400 Subject: [PATCH 1/8] update predeploy to restart old VMSS when service secrets rotated --- pkg/deploy/predeploy.go | 98 ++++++++++++++++++++++++++++++++++------- 1 file changed, 81 insertions(+), 17 deletions(-) diff --git a/pkg/deploy/predeploy.go b/pkg/deploy/predeploy.go index 4a3097a783b..42c504bab46 100644 --- a/pkg/deploy/predeploy.go +++ b/pkg/deploy/predeploy.go @@ -14,6 +14,7 @@ import ( "strings" "time" + mgmtcompute "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2020-06-01/compute" azkeyvault "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.0/keyvault" mgmtfeatures "github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2019-07-01/features" "github.com/Azure/go-autorest/autorest/azure" @@ -352,6 +353,7 @@ func (d *deployer) deployPreDeploy(ctx context.Context, resourceGroupName, deplo } func (d *deployer) configureServiceSecrets(ctx context.Context) error { + isRotated := false for _, s := range []struct { kv keyvault.Manager secretName string @@ -361,7 +363,8 @@ func (d *deployer) configureServiceSecrets(ctx context.Context) error { {d.serviceKeyvault, env.FrontendEncryptionSecretV2Name, 64}, {d.portalKeyvault, env.PortalServerSessionKeySecretName, 32}, } { - err := d.ensureAndRotateSecret(ctx, s.kv, s.secretName, s.len) + isNew, err := d.ensureAndRotateSecret(ctx, s.kv, s.secretName, s.len) + isRotated = isNew || isRotated if err != nil { return err } @@ -376,26 +379,43 @@ func (d *deployer) configureServiceSecrets(ctx context.Context) error { {d.serviceKeyvault, env.EncryptionSecretName, 32}, {d.serviceKeyvault, env.FrontendEncryptionSecretName, 32}, } { - err := d.ensureSecret(ctx, s.kv, s.secretName, s.len) + isNew, err := d.ensureSecret(ctx, s.kv, s.secretName, s.len) + isRotated = isNew || isRotated if err != nil { return err } } - return d.ensureSecretKey(ctx, d.portalKeyvault, env.PortalServerSSHKeySecretName) + isNew, err := d.ensureSecretKey(ctx, d.portalKeyvault, env.PortalServerSSHKeySecretName) + isRotated = isNew || isRotated + if err != nil { + return err + } + + if isRotated { + err = d.restartOldScalesets(ctx, "systemctl restart aro-gateway", d.config.GatewayResourceGroupName) + if err != nil { + return err + } + err = d.restartOldScalesets(ctx, "systemctl restart aro-rp", d.config.RPResourceGroupName) + if err != nil { + return err + } + } + return nil } -func (d *deployer) ensureAndRotateSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) error { +func (d *deployer) ensureAndRotateSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) (isNew bool, err error) { existingSecrets, err := kv.GetSecrets(ctx) if err != nil { - return err + return false, err } for _, secret := range existingSecrets { if filepath.Base(*secret.ID) == secretName { latestVersion, err := kv.GetSecret(ctx, secretName) if err != nil { - return err + return false, err } updatedTime := time.Unix(0, latestVersion.Attributes.Created.Duration().Nanoseconds()).Add(rotateSecretAfter) @@ -403,27 +423,27 @@ func (d *deployer) ensureAndRotateSecret(ctx context.Context, kv keyvault.Manage // do not create a secret if rotateSecretAfter time has // not elapsed since the secret version's creation timestamp if time.Now().Before(updatedTime) { - return nil + return false, nil } } } - return d.createSecret(ctx, kv, secretName, len) + return true, d.createSecret(ctx, kv, secretName, len) } -func (d *deployer) ensureSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) error { +func (d *deployer) ensureSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) (isNew bool, err error) { existingSecrets, err := kv.GetSecrets(ctx) if err != nil { - return err + return false, err } for _, secret := range existingSecrets { if filepath.Base(*secret.ID) == secretName { - return nil + return false, nil } } - return d.createSecret(ctx, kv, secretName, len) + return true, d.createSecret(ctx, kv, secretName, len) } func (d *deployer) createSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) error { @@ -439,25 +459,69 @@ func (d *deployer) createSecret(ctx context.Context, kv keyvault.Manager, secret }) } -func (d *deployer) ensureSecretKey(ctx context.Context, kv keyvault.Manager, secretName string) error { +func (d *deployer) ensureSecretKey(ctx context.Context, kv keyvault.Manager, secretName string) (isNew bool, err error) { existingSecrets, err := kv.GetSecrets(ctx) if err != nil { - return err + return false, err } for _, secret := range existingSecrets { if filepath.Base(*secret.ID) == secretName { - return nil + return false, nil } } key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - return err + return false, err } d.log.Infof("setting %s", secretName) - return kv.SetSecret(ctx, secretName, azkeyvault.SecretSetParameters{ + return true, kv.SetSecret(ctx, secretName, azkeyvault.SecretSetParameters{ Value: to.StringPtr(base64.StdEncoding.EncodeToString(x509.MarshalPKCS1PrivateKey(key))), }) } + +func (d *deployer) restartOldScalesets(ctx context.Context, script string, resourceGroupName string) error { + d.log.Print("restarting old scalesets") + scalesets, err := d.vmss.List(ctx, resourceGroupName) + if err != nil { + return err + } + + for _, vmss := range scalesets { + err = d.restartOldScaleset(ctx, *vmss.Name, script, resourceGroupName) + if err != nil { + return err + } + } + + return nil +} + +func (d *deployer) restartOldScaleset(ctx context.Context, vmssName string, script string, resourceGroupName string) error { + scalesetVMs, err := d.vmssvms.List(ctx, resourceGroupName, vmssName, "", "", "") + if err != nil { + return err + } + + d.log.Printf("restarting scaleset %s", vmssName) + errors := make(chan error, len(scalesetVMs)) + for _, vm := range scalesetVMs { + go func(id string) { + errors <- d.vmssvms.RunCommandAndWait(ctx, resourceGroupName, vmssName, id, mgmtcompute.RunCommandInput{ + CommandID: to.StringPtr("RunShellScript"), + Script: &[]string{script}, + }) + }(*vm.InstanceID) + } + + d.log.Print("waiting for instances to restart") + for range scalesetVMs { + err := <-errors + if err != nil { + return err + } + } + return nil +} From a03041136fc4d400cbe174cf461e77e759153094 Mon Sep 17 00:00:00 2001 From: Rajdeep Singh Chauhan Date: Wed, 14 Jun 2023 17:29:05 -0400 Subject: [PATCH 2/8] update scalesetVMSS conditions check for restart at RP predeploy --- pkg/deploy/predeploy.go | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/pkg/deploy/predeploy.go b/pkg/deploy/predeploy.go index 42c504bab46..50ec43596b2 100644 --- a/pkg/deploy/predeploy.go +++ b/pkg/deploy/predeploy.go @@ -27,9 +27,13 @@ import ( "github.com/Azure/ARO-RP/pkg/util/keyvault" ) -// Rotate the secret on every deploy of the RP if the most recent -// secret is greater than 7 days old -const rotateSecretAfter = time.Hour * 168 +const ( + // Rotate the secret on every deploy of the RP if the most recent + // secret is greater than 7 days old + rotateSecretAfter = time.Hour * 168 + rpRestartScript = "systemctl restart aro-rp" + gatewayRestartScript = "systemctl restart aro-gateway" +) // PreDeploy deploys managed identity, NSGs and keyvaults, needed for main // deployment @@ -393,11 +397,11 @@ func (d *deployer) configureServiceSecrets(ctx context.Context) error { } if isRotated { - err = d.restartOldScalesets(ctx, "systemctl restart aro-gateway", d.config.GatewayResourceGroupName) + err = d.restartOldScalesets(ctx, d.config.GatewayResourceGroupName) if err != nil { return err } - err = d.restartOldScalesets(ctx, "systemctl restart aro-rp", d.config.RPResourceGroupName) + err = d.restartOldScalesets(ctx, d.config.RPResourceGroupName) if err != nil { return err } @@ -482,7 +486,7 @@ func (d *deployer) ensureSecretKey(ctx context.Context, kv keyvault.Manager, sec }) } -func (d *deployer) restartOldScalesets(ctx context.Context, script string, resourceGroupName string) error { +func (d *deployer) restartOldScalesets(ctx context.Context, resourceGroupName string) error { d.log.Print("restarting old scalesets") scalesets, err := d.vmss.List(ctx, resourceGroupName) if err != nil { @@ -490,7 +494,7 @@ func (d *deployer) restartOldScalesets(ctx context.Context, script string, resou } for _, vmss := range scalesets { - err = d.restartOldScaleset(ctx, *vmss.Name, script, resourceGroupName) + err = d.restartOldScaleset(ctx, *vmss.Name, resourceGroupName) if err != nil { return err } @@ -499,7 +503,17 @@ func (d *deployer) restartOldScalesets(ctx context.Context, script string, resou return nil } -func (d *deployer) restartOldScaleset(ctx context.Context, vmssName string, script string, resourceGroupName string) error { +func (d *deployer) restartOldScaleset(ctx context.Context, vmssName string, resourceGroupName string) error { + var restartScript string + switch { + case strings.HasPrefix(vmssName, gatewayVMSSPrefix): + restartScript = gatewayRestartScript + case strings.HasPrefix(vmssName, rpVMSSPrefix): + restartScript = rpRestartScript + default: + return nil + } + scalesetVMs, err := d.vmssvms.List(ctx, resourceGroupName, vmssName, "", "", "") if err != nil { return err @@ -511,7 +525,7 @@ func (d *deployer) restartOldScaleset(ctx context.Context, vmssName string, scri go func(id string) { errors <- d.vmssvms.RunCommandAndWait(ctx, resourceGroupName, vmssName, id, mgmtcompute.RunCommandInput{ CommandID: to.StringPtr("RunShellScript"), - Script: &[]string{script}, + Script: &[]string{restartScript}, }) }(*vm.InstanceID) } From 0cf0937a02455d6366ca804f1222e20aa9a2c599 Mon Sep 17 00:00:00 2001 From: Rajdeep Singh Chauhan Date: Tue, 20 Jun 2023 13:10:50 -0400 Subject: [PATCH 3/8] add vmss health check after vmss restart --- pkg/deploy/predeploy.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pkg/deploy/predeploy.go b/pkg/deploy/predeploy.go index 50ec43596b2..29aa89e7206 100644 --- a/pkg/deploy/predeploy.go +++ b/pkg/deploy/predeploy.go @@ -505,11 +505,14 @@ func (d *deployer) restartOldScalesets(ctx context.Context, resourceGroupName st func (d *deployer) restartOldScaleset(ctx context.Context, vmssName string, resourceGroupName string) error { var restartScript string + var waitForReadiness func(ctx context.Context, vmssName string) error switch { case strings.HasPrefix(vmssName, gatewayVMSSPrefix): restartScript = gatewayRestartScript + waitForReadiness = d.gatewayWaitForReadiness case strings.HasPrefix(vmssName, rpVMSSPrefix): restartScript = rpRestartScript + waitForReadiness = d.rpWaitForReadiness default: return nil } @@ -537,5 +540,14 @@ func (d *deployer) restartOldScaleset(ctx context.Context, vmssName string, reso return err } } + + // wait for load balancer probe to change the health status + time.Sleep(30 * time.Second) + timeoutCtx, cancel := context.WithTimeout(ctx, time.Hour) + defer cancel() + err = waitForReadiness(timeoutCtx, vmssName) + if err != nil { + return err + } return nil } From c563d48e41c6b7892e0523bbfd9f040d61afad9b Mon Sep 17 00:00:00 2001 From: Rajdeep Singh Chauhan Date: Thu, 22 Jun 2023 09:10:28 -0400 Subject: [PATCH 4/8] nit changes related to logging --- pkg/deploy/predeploy.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/deploy/predeploy.go b/pkg/deploy/predeploy.go index 29aa89e7206..423d6336aa4 100644 --- a/pkg/deploy/predeploy.go +++ b/pkg/deploy/predeploy.go @@ -30,7 +30,7 @@ import ( const ( // Rotate the secret on every deploy of the RP if the most recent // secret is greater than 7 days old - rotateSecretAfter = time.Hour * 168 + rotateSecretAfter = time.Hour * 24 * 7 rpRestartScript = "systemctl restart aro-rp" gatewayRestartScript = "systemctl restart aro-gateway" ) @@ -533,7 +533,7 @@ func (d *deployer) restartOldScaleset(ctx context.Context, vmssName string, reso }(*vm.InstanceID) } - d.log.Print("waiting for instances to restart") + d.log.Print("waiting for restart script to complete") for range scalesetVMs { err := <-errors if err != nil { From 5f9b2664c58fa781cde49a80b907b8b660cf5124 Mon Sep 17 00:00:00 2001 From: Rajdeep Singh Chauhan Date: Wed, 28 Jun 2023 17:13:24 -0400 Subject: [PATCH 5/8] remove concurrent rp service restarts --- pkg/deploy/predeploy.go | 55 ++++++++++++++++++++--------------- pkg/deploy/upgrade_gateway.go | 5 +--- pkg/deploy/upgrade_rp.go | 5 +--- 3 files changed, 33 insertions(+), 32 deletions(-) diff --git a/pkg/deploy/predeploy.go b/pkg/deploy/predeploy.go index 423d6336aa4..e15a1dbf1b5 100644 --- a/pkg/deploy/predeploy.go +++ b/pkg/deploy/predeploy.go @@ -19,6 +19,7 @@ import ( mgmtfeatures "github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2019-07-01/features" "github.com/Azure/go-autorest/autorest/azure" "github.com/Azure/go-autorest/autorest/to" + "k8s.io/apimachinery/pkg/util/wait" "github.com/Azure/ARO-RP/pkg/deploy/assets" "github.com/Azure/ARO-RP/pkg/deploy/generator" @@ -487,7 +488,6 @@ func (d *deployer) ensureSecretKey(ctx context.Context, kv keyvault.Manager, sec } func (d *deployer) restartOldScalesets(ctx context.Context, resourceGroupName string) error { - d.log.Print("restarting old scalesets") scalesets, err := d.vmss.List(ctx, resourceGroupName) if err != nil { return err @@ -505,14 +505,11 @@ func (d *deployer) restartOldScalesets(ctx context.Context, resourceGroupName st func (d *deployer) restartOldScaleset(ctx context.Context, vmssName string, resourceGroupName string) error { var restartScript string - var waitForReadiness func(ctx context.Context, vmssName string) error switch { case strings.HasPrefix(vmssName, gatewayVMSSPrefix): restartScript = gatewayRestartScript - waitForReadiness = d.gatewayWaitForReadiness case strings.HasPrefix(vmssName, rpVMSSPrefix): restartScript = rpRestartScript - waitForReadiness = d.rpWaitForReadiness default: return nil } @@ -522,32 +519,42 @@ func (d *deployer) restartOldScaleset(ctx context.Context, vmssName string, reso return err } - d.log.Printf("restarting scaleset %s", vmssName) - errors := make(chan error, len(scalesetVMs)) for _, vm := range scalesetVMs { - go func(id string) { - errors <- d.vmssvms.RunCommandAndWait(ctx, resourceGroupName, vmssName, id, mgmtcompute.RunCommandInput{ - CommandID: to.StringPtr("RunShellScript"), - Script: &[]string{restartScript}, - }) - }(*vm.InstanceID) - } + d.log.Print("waiting for restart script to complete on older vmss %s, instance %s", vmssName, *vm.InstanceID) + err = d.vmssvms.RunCommandAndWait(ctx, resourceGroupName, vmssName, *vm.InstanceID, mgmtcompute.RunCommandInput{ + CommandID: to.StringPtr("RunShellScript"), + Script: &[]string{restartScript}, + }) - d.log.Print("waiting for restart script to complete") - for range scalesetVMs { - err := <-errors if err != nil { return err } - } - // wait for load balancer probe to change the health status - time.Sleep(30 * time.Second) - timeoutCtx, cancel := context.WithTimeout(ctx, time.Hour) - defer cancel() - err = waitForReadiness(timeoutCtx, vmssName) - if err != nil { - return err + // wait for load balancer probe to change the vm health status + time.Sleep(30 * time.Second) + timeoutCtx, cancel := context.WithTimeout(ctx, time.Hour) + defer cancel() + err = d.waitForReadiness(timeoutCtx, vmssName, *vm.InstanceID) + if err != nil { + return err + } } + return nil } + +func (d *deployer) waitForReadiness(ctx context.Context, vmssName string, vmInstanceID string) error { + return wait.PollImmediateUntil(10*time.Second, func() (bool, error) { + return d.isVMInstanceHealthy(ctx, vmssName, vmInstanceID), nil + }, ctx.Done()) +} + +func (d *deployer) isVMInstanceHealthy(ctx context.Context, vmssName string, vmInstanceID string) bool { + r, err := d.vmssvms.GetInstanceView(ctx, d.config.RPResourceGroupName, vmssName, vmInstanceID) + instanceUnhealthy := r.VMHealth != nil && r.VMHealth.Status != nil && r.VMHealth.Status.Code != nil && *r.VMHealth.Status.Code != "HealthState/healthy" + if err != nil || instanceUnhealthy { + d.log.Printf("instance %s status %s", vmInstanceID, *r.VMHealth.Status.Code) + return false + } + return true +} diff --git a/pkg/deploy/upgrade_gateway.go b/pkg/deploy/upgrade_gateway.go index a5eedb87d07..3af14fb1556 100644 --- a/pkg/deploy/upgrade_gateway.go +++ b/pkg/deploy/upgrade_gateway.go @@ -40,10 +40,7 @@ func (d *deployer) gatewayWaitForReadiness(ctx context.Context, vmssName string) d.log.Printf("waiting for %s instances to be healthy", vmssName) return wait.PollImmediateUntil(10*time.Second, func() (bool, error) { for _, vm := range scalesetVMs { - r, err := d.vmssvms.GetInstanceView(ctx, d.config.GatewayResourceGroupName, vmssName, *vm.InstanceID) - instanceUnhealthy := r.VMHealth != nil && r.VMHealth.Status != nil && r.VMHealth.Status.Code != nil && *r.VMHealth.Status.Code != "HealthState/healthy" - if err != nil || instanceUnhealthy { - d.log.Printf("instance %s status %s", *vm.InstanceID, *r.VMHealth.Status.Code) + if !d.isVMInstanceHealthy(ctx, vmssName, *vm.InstanceID) { return false, nil } } diff --git a/pkg/deploy/upgrade_rp.go b/pkg/deploy/upgrade_rp.go index 875d61c28c1..3bf970a432d 100644 --- a/pkg/deploy/upgrade_rp.go +++ b/pkg/deploy/upgrade_rp.go @@ -40,10 +40,7 @@ func (d *deployer) rpWaitForReadiness(ctx context.Context, vmssName string) erro d.log.Printf("waiting for %s instances to be healthy", vmssName) return wait.PollImmediateUntil(10*time.Second, func() (bool, error) { for _, vm := range scalesetVMs { - r, err := d.vmssvms.GetInstanceView(ctx, d.config.RPResourceGroupName, vmssName, *vm.InstanceID) - instanceUnhealthy := r.VMHealth != nil && r.VMHealth.Status != nil && r.VMHealth.Status.Code != nil && *r.VMHealth.Status.Code != "HealthState/healthy" - if err != nil || instanceUnhealthy { - d.log.Printf("instance %s status %s", *vm.InstanceID, *r.VMHealth.Status.Code) + if !d.isVMInstanceHealthy(ctx, vmssName, *vm.InstanceID) { return false, nil } } From a1078f95f0f580dc985fc4546e8af9d78ea13b9e Mon Sep 17 00:00:00 2001 From: Rajdeep Singh Chauhan Date: Tue, 18 Jul 2023 13:30:38 -0400 Subject: [PATCH 6/8] Add unit test cases for RP predeploy function --- pkg/deploy/predeploy.go | 27 +- pkg/deploy/predeploy_test.go | 1949 ++++++++++++++++++++ pkg/deploy/upgrade_gateway.go | 2 +- pkg/deploy/upgrade_rp.go | 2 +- pkg/util/azureclient/mgmt/msi/generate.go | 8 + pkg/util/mocks/azureclient/mgmt/msi/msi.go | 51 + 6 files changed, 2029 insertions(+), 10 deletions(-) create mode 100644 pkg/deploy/predeploy_test.go create mode 100644 pkg/util/azureclient/mgmt/msi/generate.go create mode 100644 pkg/util/mocks/azureclient/mgmt/msi/msi.go diff --git a/pkg/deploy/predeploy.go b/pkg/deploy/predeploy.go index e15a1dbf1b5..ff223f04fcf 100644 --- a/pkg/deploy/predeploy.go +++ b/pkg/deploy/predeploy.go @@ -10,6 +10,8 @@ import ( "crypto/x509" "encoding/base64" "encoding/json" + "fmt" + "net/http" "path/filepath" "strings" "time" @@ -21,6 +23,7 @@ import ( "github.com/Azure/go-autorest/autorest/to" "k8s.io/apimachinery/pkg/util/wait" + "github.com/Azure/ARO-RP/pkg/api" "github.com/Azure/ARO-RP/pkg/deploy/assets" "github.com/Azure/ARO-RP/pkg/deploy/generator" "github.com/Azure/ARO-RP/pkg/env" @@ -511,7 +514,15 @@ func (d *deployer) restartOldScaleset(ctx context.Context, vmssName string, reso case strings.HasPrefix(vmssName, rpVMSSPrefix): restartScript = rpRestartScript default: - return nil + return &api.CloudError{ + StatusCode: http.StatusBadRequest, + CloudErrorBody: &api.CloudErrorBody{ + Code: api.CloudErrorCodeInvalidResource, + Message: fmt.Sprintf("provided vmss %s does not match RP or gateway prefix", + vmssName, + ), + }, + } } scalesetVMs, err := d.vmssvms.List(ctx, resourceGroupName, vmssName, "", "", "") @@ -520,7 +531,7 @@ func (d *deployer) restartOldScaleset(ctx context.Context, vmssName string, reso } for _, vm := range scalesetVMs { - d.log.Print("waiting for restart script to complete on older vmss %s, instance %s", vmssName, *vm.InstanceID) + d.log.Printf("waiting for restart script to complete on older vmss %s, instance %s", vmssName, *vm.InstanceID) err = d.vmssvms.RunCommandAndWait(ctx, resourceGroupName, vmssName, *vm.InstanceID, mgmtcompute.RunCommandInput{ CommandID: to.StringPtr("RunShellScript"), Script: &[]string{restartScript}, @@ -534,7 +545,7 @@ func (d *deployer) restartOldScaleset(ctx context.Context, vmssName string, reso time.Sleep(30 * time.Second) timeoutCtx, cancel := context.WithTimeout(ctx, time.Hour) defer cancel() - err = d.waitForReadiness(timeoutCtx, vmssName, *vm.InstanceID) + err = d.waitForReadiness(timeoutCtx, resourceGroupName, vmssName, *vm.InstanceID) if err != nil { return err } @@ -543,17 +554,17 @@ func (d *deployer) restartOldScaleset(ctx context.Context, vmssName string, reso return nil } -func (d *deployer) waitForReadiness(ctx context.Context, vmssName string, vmInstanceID string) error { +func (d *deployer) waitForReadiness(ctx context.Context, resourceGroupName string, vmssName string, vmInstanceID string) error { return wait.PollImmediateUntil(10*time.Second, func() (bool, error) { - return d.isVMInstanceHealthy(ctx, vmssName, vmInstanceID), nil + return d.isVMInstanceHealthy(ctx, resourceGroupName, vmssName, vmInstanceID), nil }, ctx.Done()) } -func (d *deployer) isVMInstanceHealthy(ctx context.Context, vmssName string, vmInstanceID string) bool { - r, err := d.vmssvms.GetInstanceView(ctx, d.config.RPResourceGroupName, vmssName, vmInstanceID) +func (d *deployer) isVMInstanceHealthy(ctx context.Context, resourceGroupName string, vmssName string, vmInstanceID string) bool { + r, err := d.vmssvms.GetInstanceView(ctx, resourceGroupName, vmssName, vmInstanceID) instanceUnhealthy := r.VMHealth != nil && r.VMHealth.Status != nil && r.VMHealth.Status.Code != nil && *r.VMHealth.Status.Code != "HealthState/healthy" if err != nil || instanceUnhealthy { - d.log.Printf("instance %s status %s", vmInstanceID, *r.VMHealth.Status.Code) + d.log.Printf("instance %s is unhealthy", vmInstanceID) return false } return true diff --git a/pkg/deploy/predeploy_test.go b/pkg/deploy/predeploy_test.go new file mode 100644 index 00000000000..a02aa3c885e --- /dev/null +++ b/pkg/deploy/predeploy_test.go @@ -0,0 +1,1949 @@ +package deploy + +// Copyright (c) Microsoft Corporation. +// Licensed under the Apache License 2.0. + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + mgmtcompute "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2020-06-01/compute" + azkeyvault "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.0/keyvault" + mgmtmsi "github.com/Azure/azure-sdk-for-go/services/msi/mgmt/2018-11-30/msi" + mgmtfeatures "github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2019-07-01/features" + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/azure" + "github.com/Azure/go-autorest/autorest/date" + "github.com/Azure/go-autorest/autorest/to" + gofrsuuid "github.com/gofrs/uuid" + "github.com/golang/mock/gomock" + "github.com/sirupsen/logrus" + + "github.com/Azure/ARO-RP/pkg/deploy/generator" + "github.com/Azure/ARO-RP/pkg/env" + mock_compute "github.com/Azure/ARO-RP/pkg/util/mocks/azureclient/mgmt/compute" + mock_features "github.com/Azure/ARO-RP/pkg/util/mocks/azureclient/mgmt/features" + mock_msi "github.com/Azure/ARO-RP/pkg/util/mocks/azureclient/mgmt/msi" + mock_keyvault "github.com/Azure/ARO-RP/pkg/util/mocks/keyvault" + utilerror "github.com/Azure/ARO-RP/test/util/error" +) + +func TestPreDeploy(t *testing.T) { + ctx := context.Background() + subscriptionRgName := "testRG-subscription" + globalRgName := "testRG-global" + rpRgName := "testRG-aro-rp" + gatewayRgName := "testRG-gwy" + location := "testLocation" + overrideLocation := "overrideTestLocation" + group := mgmtfeatures.ResourceGroup{ + Location: &location, + } + fakeMSIObjectId, _ := gofrsuuid.NewV4() + msi := mgmtmsi.Identity{ + UserAssignedIdentityProperties: &mgmtmsi.UserAssignedIdentityProperties{ + PrincipalID: &fakeMSIObjectId, + }, + } + deployment := mgmtfeatures.DeploymentExtended{} + partialSecretItems := []azkeyvault.SecretItem{ + { + ID: to.StringPtr("test1"), + }, + { + ID: to.StringPtr(env.EncryptionSecretV2Name), + }, + { + ID: to.StringPtr(env.FrontendEncryptionSecretV2Name), + }, + } + rpVMSSName := rpVMSSPrefix + "test" + nowUnixTime := date.NewUnixTimeFromSeconds(float64(time.Now().Unix())) + newSecretBundle := azkeyvault.SecretBundle{ + Attributes: &azkeyvault.SecretAttributes{ + Created: &nowUnixTime, + }, + } + vmsss := []mgmtcompute.VirtualMachineScaleSet{ + { + Name: to.StringPtr(rpVMSSName), + }, + } + allSecretItems := []azkeyvault.SecretItem{ + { + ID: to.StringPtr("test1"), + }, + { + ID: to.StringPtr(env.EncryptionSecretV2Name), + }, + { + ID: to.StringPtr(env.FrontendEncryptionSecretV2Name), + }, + { + ID: to.StringPtr(env.PortalServerSessionKeySecretName), + }, + { + ID: to.StringPtr(env.EncryptionSecretName), + }, + { + ID: to.StringPtr(env.FrontendEncryptionSecretName), + }, + { + ID: to.StringPtr(env.PortalServerSSHKeySecretName), + }, + } + instanceID := "testID" + vms := []mgmtcompute.VirtualMachineScaleSetVM{ + { + InstanceID: to.StringPtr(instanceID), + }, + } + healthyVMSS := mgmtcompute.VirtualMachineScaleSetVMInstanceView{ + VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ + Status: &mgmtcompute.InstanceViewStatus{ + Code: to.StringPtr("HealthState/healthy"), + }, + }, + } + + type mock func(*mock_features.MockDeploymentsClient, *mock_features.MockResourceGroupsClient, *mock_msi.MockUserAssignedIdentitiesClient, *mock_keyvault.MockManager, *mock_compute.MockVirtualMachineScaleSetsClient, *mock_compute.MockVirtualMachineScaleSetVMsClient) + genericSubscriptionDeploymentFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + d.EXPECT().CreateOrUpdateAtSubscriptionScopeAndWait(ctx, gomock.Any(), gomock.Any()).Return( + errors.New("generic error"), + ).AnyTimes() + } + subscriptionDeploymentSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + d.EXPECT().CreateOrUpdateAtSubscriptionScopeAndWait(ctx, gomock.Any(), gomock.Any()).Return(nil) + } + subscriptionRGDeploymentFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + d.EXPECT().CreateOrUpdateAndWait(ctx, subscriptionRgName, gomock.Any(), gomock.Any()).Return( + errors.New("generic error"), + ) + } + globalRGDeploymentFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + d.EXPECT().CreateOrUpdateAndWait(ctx, globalRgName, gomock.Any(), gomock.Any()).Return( + errors.New("generic error"), + ) + } + gatewayRGDeploymentFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + d.EXPECT().CreateOrUpdateAndWait(ctx, gatewayRgName, gomock.Any(), gomock.Any()).Return( + errors.New("generic error"), + ) + } + rpRGDeploymentFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + d.EXPECT().CreateOrUpdateAndWait(ctx, rpRgName, gomock.Any(), gomock.Any()).Return( + errors.New("generic error"), + ) + } + deploymentSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + d.EXPECT().CreateOrUpdateAndWait(ctx, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + } + subscriptionResourceGroupDeploymentFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + rg.EXPECT().CreateOrUpdate(ctx, subscriptionRgName, gomock.Any()).Return( + group, + errors.New("generic error"), + ) + } + globalResourceGroupDeploymentFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + rg.EXPECT().CreateOrUpdate(ctx, globalRgName, gomock.Any()).Return( + group, + errors.New("generic error"), + ) + } + rpResourceGroupDeploymentFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + rg.EXPECT().CreateOrUpdate(ctx, rpRgName, gomock.Any()).Return( + group, + errors.New("generic error"), + ) + } + gatewayResourceGroupDeploymentFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + rg.EXPECT().CreateOrUpdate(ctx, gatewayRgName, gomock.Any()).Return( + group, + errors.New("generic error"), + ) + } + resourceGroupDeploymentSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + rg.EXPECT().CreateOrUpdate(ctx, gomock.Any(), gomock.Any()).Return(group, nil) + } + rpMSIGetFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + m.EXPECT().Get(ctx, rpRgName, gomock.Any()).Return(msi, errors.New("generic error")) + } + rpMSIGetSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + m.EXPECT().Get(ctx, rpRgName, gomock.Any()).Return(msi, nil) + } + gatewayMSIGetFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + m.EXPECT().Get(ctx, gatewayRgName, gomock.Any()).Return(msi, errors.New("generic error")) + } + gatewayMSIGetSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + m.EXPECT().Get(ctx, gatewayRgName, gomock.Any()).Return(msi, nil) + } + getDeploymentFailedWithDeploymentNotFound := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + d.EXPECT().Get(ctx, gatewayRgName, gomock.Any()).Return(deployment, autorest.DetailedError{ + Original: &azure.RequestError{ + ServiceError: &azure.ServiceError{ + Code: "DeploymentNotFound", + Details: []map[string]interface{}{ + {}, + }, + }, + }, + }) + } + getSecretsFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + k.EXPECT().GetSecrets(ctx).Return( + partialSecretItems, errors.New("generic error"), + ) + } + getSecretsSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + k.EXPECT().GetSecrets(ctx).Return( + allSecretItems, nil, + ) + } + getNewSecretSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + k.EXPECT().GetSecret(ctx, gomock.Any()).Return( + newSecretBundle, nil, + ) + } + getPartialSecretsSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + k.EXPECT().GetSecrets(ctx).Return( + partialSecretItems, nil, + ) + } + setSecretSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + k.EXPECT().SetSecret(ctx, gomock.Any(), gomock.Any()).Return( + nil, + ) + } + vmssListSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + vmss.EXPECT().List(ctx, gomock.Any()).Return( + vmsss, nil, + ) + } + vmssVMsListSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + vmssvms.EXPECT().List(ctx, gomock.Any(), gomock.Any(), "", "", "").Return( + vms, nil, + ) + } + restartSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + vmssvms.EXPECT().RunCommandAndWait(ctx, gomock.Any(), gomock.Any(), instanceID, gomock.Any()).Return(nil) + } + healthyInstanceView := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + vmssvms.EXPECT().GetInstanceView(gomock.Any(), gomock.Any(), gomock.Any(), instanceID).Return(healthyVMSS, nil) + } + + for _, tt := range []struct { + name string + location string + overrideLocation string + acrReplicaDisabled bool + subscriptionRgName string + globalResourceGroup string + rpResourceGroupName string + gatewayResourceGroupName string + mocks []mock + wantErr string + }{ + { + name: "don't continue if Global Subscription RBAC DeploymentFailed", + location: location, + mocks: []mock{ + genericSubscriptionDeploymentFailed, + }, + wantErr: "generic error", + }, + { + name: "don't continue if Global Subscription RBAC Deployment is Successful but SubscriptionResourceGroup creation fails", + location: location, + subscriptionRgName: subscriptionRgName, + mocks: []mock{ + subscriptionDeploymentSuccessful, subscriptionResourceGroupDeploymentFailed, + }, + wantErr: "generic error", + }, + { + name: "don't continue if Global Subscription RBAC Deployment is Successful but GlobalResourceGroup creation fails", + location: location, + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + mocks: []mock{ + subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, globalResourceGroupDeploymentFailed, + }, + wantErr: "generic error", + }, + { + name: "don't continue if Global Subscription RBAC Deployment is Successful but RPResourceGroup creation fails", + location: location, + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + mocks: []mock{ + subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, rpResourceGroupDeploymentFailed, + }, + wantErr: "generic error", + }, + { + name: "don't continue if Global Subscription RBAC Deployment is successful but GatewayResourceGroup creation fails", + location: location, + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + mocks: []mock{ + subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, gatewayResourceGroupDeploymentFailed, + }, + wantErr: "generic error", + }, + { + name: "don't continue if Global Subscription RBAC Deployment & resource group creation is successful but rp-subscription template deployment fails", + location: location, + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + mocks: []mock{ + subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, subscriptionRGDeploymentFailed, + }, + wantErr: "generic error", + }, + { + name: "don't continue if Global Subscription RBAC Deployment, resource group creation and rp-subscription template deployment is successful but rp managed identity get fails", + location: location, + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + mocks: []mock{ + subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetFailed, + }, + wantErr: "generic error", + }, + { + name: "don't continue if Global Subscription RBAC Deployment, resource group creation and rp-subscription template deployment is successful but gateway managed identity get fails", + location: location, + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + mocks: []mock{ + subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetSuccessful, deploymentSuccessful, gatewayMSIGetFailed, + }, + wantErr: "generic error", + }, + { + name: "don't continue if Global Subscription RBAC Deployment, resource group creation and rp-subscription template deployment, msi get is successful but rpglobal deployment get fails", + location: location, + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + mocks: []mock{ + subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetSuccessful, deploymentSuccessful, gatewayMSIGetSuccessful, globalRGDeploymentFailed, + }, + wantErr: "generic error", + }, + { + name: "don't continue if Global Subscription RBAC Deployment, resource group creation and rp-subscription template deployment, msi get is successful but rpglobal deployment get fails", + location: location, + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + mocks: []mock{ + subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetSuccessful, deploymentSuccessful, gatewayMSIGetSuccessful, globalRGDeploymentFailed, + }, + wantErr: "generic error", + }, + { + name: "don't continue if Global Subscription RBAC Deployment, resource group creation, rp-subscription deployment, rpglobal deployment is successful but ACR Replication fails", + location: location, + overrideLocation: overrideLocation, + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + mocks: []mock{ + subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetSuccessful, deploymentSuccessful, gatewayMSIGetSuccessful, deploymentSuccessful, globalRGDeploymentFailed, + }, + wantErr: "generic error", + }, + { + name: "don't continue if skipping ACR Replication due to no ACRLocationOverride but failing gateway predeploy", + location: location, + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + mocks: []mock{ + subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetSuccessful, deploymentSuccessful, gatewayMSIGetSuccessful, deploymentSuccessful, getDeploymentFailedWithDeploymentNotFound, gatewayRGDeploymentFailed, + }, + wantErr: "generic error", + }, + { + name: "don't continue if skipping ACR Replication due to same ACRLocationOverride as location but failing gateway predeploy", + location: location, + overrideLocation: location, + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + mocks: []mock{ + subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetSuccessful, deploymentSuccessful, gatewayMSIGetSuccessful, deploymentSuccessful, getDeploymentFailedWithDeploymentNotFound, gatewayRGDeploymentFailed, + }, + wantErr: "generic error", + }, + { + name: "don't continue if skipping ACR Replication due to ACRReplicaDisabled but failing gateway predeploy", + location: location, + overrideLocation: overrideLocation, + acrReplicaDisabled: true, + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + mocks: []mock{ + subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetSuccessful, deploymentSuccessful, gatewayMSIGetSuccessful, deploymentSuccessful, getDeploymentFailedWithDeploymentNotFound, gatewayRGDeploymentFailed, + }, + wantErr: "generic error", + }, + { + name: "don't continue gateway predeploy is successful but rp predeploy failed", + location: location, + overrideLocation: overrideLocation, + acrReplicaDisabled: true, + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + mocks: []mock{ + subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetSuccessful, deploymentSuccessful, gatewayMSIGetSuccessful, deploymentSuccessful, getDeploymentFailedWithDeploymentNotFound, deploymentSuccessful, rpRGDeploymentFailed, + }, + wantErr: "generic error", + }, + { + name: "get error for the configureServiceSecrets", + location: location, + overrideLocation: overrideLocation, + acrReplicaDisabled: true, + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + mocks: []mock{ + subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetSuccessful, deploymentSuccessful, gatewayMSIGetSuccessful, deploymentSuccessful, getDeploymentFailedWithDeploymentNotFound, deploymentSuccessful, deploymentSuccessful, getSecretsFailed, + }, + wantErr: "generic error", + }, + { + name: "Everything is successful", + location: location, + overrideLocation: overrideLocation, + acrReplicaDisabled: true, + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + mocks: []mock{ + subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetSuccessful, deploymentSuccessful, gatewayMSIGetSuccessful, deploymentSuccessful, getDeploymentFailedWithDeploymentNotFound, deploymentSuccessful, deploymentSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, getSecretsSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, vmssListSuccessful, vmssVMsListSuccessful, restartSuccessful, healthyInstanceView, vmssListSuccessful, vmssVMsListSuccessful, restartSuccessful, healthyInstanceView, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockDeployments := mock_features.NewMockDeploymentsClient(controller) + mockResourceGroups := mock_features.NewMockResourceGroupsClient(controller) + mockMSIs := mock_msi.NewMockUserAssignedIdentitiesClient(controller) + mockKV := mock_keyvault.NewMockManager(controller) + mockVMSS := mock_compute.NewMockVirtualMachineScaleSetsClient(controller) + mockVMSSVM := mock_compute.NewMockVirtualMachineScaleSetVMsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + globaldeployments: mockDeployments, + deployments: mockDeployments, + groups: mockResourceGroups, + globalgroups: mockResourceGroups, + userassignedidentities: mockMSIs, + config: &RPConfig{ + Configuration: &Configuration{ + GlobalResourceGroupLocation: &tt.location, + SubscriptionResourceGroupLocation: &tt.location, + SubscriptionResourceGroupName: &tt.subscriptionRgName, + GlobalResourceGroupName: &tt.globalResourceGroup, + ACRLocationOverride: &tt.overrideLocation, + ACRReplicaDisabled: &tt.acrReplicaDisabled, + }, + RPResourceGroupName: tt.rpResourceGroupName, + GatewayResourceGroupName: tt.gatewayResourceGroupName, + Location: tt.location, + }, + serviceKeyvault: mockKV, + portalKeyvault: mockKV, + vmss: mockVMSS, + vmssvms: mockVMSSVM, + } + + for _, m := range tt.mocks { + m(mockDeployments, mockResourceGroups, mockMSIs, mockKV, mockVMSS, mockVMSSVM) + } + + err := d.PreDeploy(ctx) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestDeployRPGlobalSubscription(t *testing.T) { + ctx := context.Background() + location := "locationTest" + + type mock func(*mock_features.MockDeploymentsClient) + subscriptionDeploymentFailed := func(d *mock_features.MockDeploymentsClient) { + d.EXPECT().CreateOrUpdateAtSubscriptionScopeAndWait(ctx, gomock.Any(), gomock.Any()).Return( + errors.New("generic error"), + ).AnyTimes() + } + subscriptionDeploymentFailedWithDeploymentFailed := func(d *mock_features.MockDeploymentsClient) { + d.EXPECT().CreateOrUpdateAtSubscriptionScopeAndWait(ctx, gomock.Any(), gomock.Any()).Return( + &azure.ServiceError{ + Code: "DeploymentFailed", + Details: []map[string]interface{}{ + {}, + }, + }, + ) + } + subscriptionDeploymentSuccessful := func(d *mock_features.MockDeploymentsClient) { + d.EXPECT().CreateOrUpdateAtSubscriptionScopeAndWait(ctx, gomock.Any(), gomock.Any()).Return(nil) + } + + for _, tt := range []struct { + name string + deploymentFileName string + mocks []mock + wantErr string + }{ + { + name: "Don't continue if deployment fails with error other than DeploymentFailed", + mocks: []mock{subscriptionDeploymentFailed}, + wantErr: "generic error", + }, + { + name: "Don't continue if deployment fails with error DeploymentFailed five times", + mocks: []mock{subscriptionDeploymentFailedWithDeploymentFailed, subscriptionDeploymentFailedWithDeploymentFailed, subscriptionDeploymentFailedWithDeploymentFailed, subscriptionDeploymentFailedWithDeploymentFailed, subscriptionDeploymentFailedWithDeploymentFailed}, + wantErr: `Code="DeploymentFailed" Message="" Details=[{}]`, + }, + { + name: "Pass successfully when deployment is successfulin second attempt", + mocks: []mock{subscriptionDeploymentFailedWithDeploymentFailed, subscriptionDeploymentSuccessful}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockDeployments := mock_features.NewMockDeploymentsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + config: &RPConfig{ + Configuration: &Configuration{ + GlobalResourceGroupLocation: &location, + }, + Location: location, + }, + globaldeployments: mockDeployments, + } + + for _, m := range tt.mocks { + m(mockDeployments) + } + + err := d.deployRPGlobalSubscription(ctx) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestDeployRPSubscription(t *testing.T) { + ctx := context.Background() + location := "locationTest" + subscriptionRGName := "rgTest" + + type mock func(*mock_features.MockDeploymentsClient) + deploymentFailed := func(d *mock_features.MockDeploymentsClient) { + d.EXPECT().CreateOrUpdateAndWait(ctx, subscriptionRGName, gomock.Any(), gomock.Any()).Return( + errors.New("generic error"), + ) + } + deploymentSuccess := func(d *mock_features.MockDeploymentsClient) { + d.EXPECT().CreateOrUpdateAndWait(ctx, subscriptionRGName, gomock.Any(), gomock.Any()).Return( + nil, + ) + } + + for _, tt := range []struct { + name string + deploymentFileName string + mocks []mock + wantErr string + }{ + { + name: "Don't continue if deployment fails", + mocks: []mock{deploymentFailed}, + wantErr: "generic error", + }, + { + name: "Pass successfully when deployment is successful", + mocks: []mock{deploymentSuccess}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockDeployments := mock_features.NewMockDeploymentsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + config: &RPConfig{ + Configuration: &Configuration{ + SubscriptionResourceGroupName: &subscriptionRGName, + }, + Location: location, + }, + deployments: mockDeployments, + } + + for _, m := range tt.mocks { + m(mockDeployments) + } + + err := d.deployRPSubscription(ctx) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestDeployManagedIdentity(t *testing.T) { + ctx := context.Background() + rgName := "rgTest" + existingFileName := generator.FileGatewayProductionPredeploy + deploymentName := strings.TrimSuffix(existingFileName, ".json") + notExistingFileName := "testFile" + + type mock func(*mock_features.MockDeploymentsClient) + deploymentFailed := func(d *mock_features.MockDeploymentsClient) { + d.EXPECT().CreateOrUpdateAndWait(ctx, rgName, deploymentName, gomock.Any()).Return( + errors.New("generic error"), + ) + } + deploymentSuccess := func(d *mock_features.MockDeploymentsClient) { + d.EXPECT().CreateOrUpdateAndWait(ctx, rgName, deploymentName, gomock.Any()).Return( + nil, + ) + } + + for _, tt := range []struct { + name string + deploymentFileName string + mocks []mock + wantErr string + }{ + { + name: "Don't continue if deployment file does not exist", + deploymentFileName: notExistingFileName, + wantErr: "open " + notExistingFileName + ": file does not exist", + }, + { + name: "Don't continue if deployment fails", + deploymentFileName: existingFileName, + mocks: []mock{deploymentFailed}, + wantErr: "generic error", + }, + { + name: "Pass successfully when deployment is successful", + deploymentFileName: existingFileName, + mocks: []mock{deploymentSuccess}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockDeployments := mock_features.NewMockDeploymentsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + config: &RPConfig{ + Configuration: &Configuration{}, + }, + deployments: mockDeployments, + } + + for _, m := range tt.mocks { + m(mockDeployments) + } + + err := d.deployManagedIdentity(ctx, rgName, tt.deploymentFileName) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestDeployRPGlobal(t *testing.T) { + ctx := context.Background() + location := "locationTest" + globalRGName := "globalRGTest" + rpSPID := "rpSPIDTest" + gwySPID := "gwySPIDTest" + + type mock func(*mock_features.MockDeploymentsClient) + deploymentFailedWithGenericError := func(d *mock_features.MockDeploymentsClient) { + d.EXPECT().CreateOrUpdateAndWait(ctx, globalRGName, gomock.Any(), gomock.Any()).Return( + errors.New("generic error"), + ) + } + deploymentFailed := func(d *mock_features.MockDeploymentsClient) { + d.EXPECT().CreateOrUpdateAndWait(ctx, globalRGName, gomock.Any(), gomock.Any()).Return( + &azure.ServiceError{ + Code: "DeploymentFailed", + Details: []map[string]interface{}{ + {}, + }, + }, + ) + } + deploymentSuccess := func(d *mock_features.MockDeploymentsClient) { + d.EXPECT().CreateOrUpdateAndWait(ctx, globalRGName, gomock.Any(), gomock.Any()).Return( + nil, + ) + } + + for _, tt := range []struct { + name string + mocks []mock + wantErr string + }{ + { + name: "Don't continue if deployment fails with error other than DeploymentFailed", + mocks: []mock{deploymentFailedWithGenericError}, + wantErr: "generic error", + }, + { + name: "Don't continue if deployment fails with DeploymentFailed error twice", + mocks: []mock{deploymentFailed, deploymentFailed}, + wantErr: `Code="DeploymentFailed" Message="" Details=[{}]`, + }, + { + name: "Pass successfully when deployment is successful in second attempt", + mocks: []mock{deploymentFailed, deploymentSuccess}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockDeployments := mock_features.NewMockDeploymentsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + config: &RPConfig{ + Configuration: &Configuration{ + GlobalResourceGroupName: to.StringPtr(globalRGName), + }, + Location: location, + }, + globaldeployments: mockDeployments, + } + + for _, m := range tt.mocks { + m(mockDeployments) + } + + err := d.deployRPGlobal(ctx, rpSPID, gwySPID) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestDeployRPGlobalACRReplication(t *testing.T) { + ctx := context.Background() + globalRGName := "globalRGTest" + location := "testLocation" + + type mock func(*mock_features.MockDeploymentsClient) + deploymentFailed := func(d *mock_features.MockDeploymentsClient) { + d.EXPECT().CreateOrUpdateAndWait(ctx, globalRGName, gomock.Any(), gomock.Any()).Return( + errors.New("generic error"), + ) + } + deploymentSuccess := func(d *mock_features.MockDeploymentsClient) { + d.EXPECT().CreateOrUpdateAndWait(ctx, globalRGName, gomock.Any(), gomock.Any()).Return( + nil, + ) + } + + for _, tt := range []struct { + name string + mocks []mock + wantErr string + }{ + { + name: "Don't continue if deployment fails", + mocks: []mock{deploymentFailed}, + wantErr: "generic error", + }, + { + name: "Pass when deployment is successful", + mocks: []mock{deploymentSuccess}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockDeployments := mock_features.NewMockDeploymentsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + config: &RPConfig{ + Configuration: &Configuration{ + GlobalResourceGroupName: to.StringPtr(globalRGName), + }, + Location: location, + }, + globaldeployments: mockDeployments, + } + + for _, m := range tt.mocks { + m(mockDeployments) + } + + err := d.deployRPGlobalACRReplication(ctx) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestDeployPreDeploy(t *testing.T) { + ctx := context.Background() + rgName := "testRG" + gwyRGName := "testGwyRG" + existingFileName := generator.FileGatewayProductionPredeploy + deploymentName := strings.TrimSuffix(existingFileName, ".json") + notExistingFileName := "testFile" + spIDName := "testSPIDName" + spID := "testSPID" + + type mock func(*mock_features.MockDeploymentsClient) + deploymentFailed := func(d *mock_features.MockDeploymentsClient) { + d.EXPECT().CreateOrUpdateAndWait(ctx, rgName, deploymentName, gomock.Any()).Return( + errors.New("generic error"), + ) + } + deploymentSuccess := func(d *mock_features.MockDeploymentsClient) { + d.EXPECT().CreateOrUpdateAndWait(ctx, rgName, deploymentName, gomock.Any()).Return( + nil, + ) + } + + for _, tt := range []struct { + name string + resourceGroupName string + deploymentFileName string + spIDName string + spID string + isCreate bool + mocks []mock + wantErr string + }{ + { + name: "Don't continue if deployment file does not exist", + resourceGroupName: rgName, + deploymentFileName: notExistingFileName, + spIDName: spIDName, + spID: spID, + wantErr: "open " + notExistingFileName + ": file does not exist", + }, + { + name: "Don't continue if deployment fails", + resourceGroupName: rgName, + deploymentFileName: existingFileName, + spIDName: spIDName, + spID: spID, + mocks: []mock{deploymentFailed}, + wantErr: "generic error", + }, + { + name: "Pass when deployment is successful", + resourceGroupName: rgName, + deploymentFileName: existingFileName, + spIDName: spIDName, + spID: spID, + mocks: []mock{deploymentSuccess}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockDeployments := mock_features.NewMockDeploymentsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + config: &RPConfig{ + Configuration: &Configuration{}, + GatewayResourceGroupName: gwyRGName, + }, + deployments: mockDeployments, + } + + for _, m := range tt.mocks { + m(mockDeployments) + } + + err := d.deployPreDeploy(ctx, tt.resourceGroupName, tt.deploymentFileName, tt.spIDName, tt.spID, tt.isCreate) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestConfigureServiceSecrets(t *testing.T) { + ctx := context.Background() + rpVMSSName := rpVMSSPrefix + "test" + rgName := "rgTest" + nowUnixTime := date.NewUnixTimeFromSeconds(float64(time.Now().Unix())) + newSecretBundle := azkeyvault.SecretBundle{ + Attributes: &azkeyvault.SecretAttributes{ + Created: &nowUnixTime, + }, + } + vmsss := []mgmtcompute.VirtualMachineScaleSet{ + { + Name: to.StringPtr(rpVMSSName), + }, + } + allSecretItems := []azkeyvault.SecretItem{ + { + ID: to.StringPtr("test1"), + }, + { + ID: to.StringPtr(env.EncryptionSecretV2Name), + }, + { + ID: to.StringPtr(env.FrontendEncryptionSecretV2Name), + }, + { + ID: to.StringPtr(env.PortalServerSessionKeySecretName), + }, + { + ID: to.StringPtr(env.EncryptionSecretName), + }, + { + ID: to.StringPtr(env.FrontendEncryptionSecretName), + }, + { + ID: to.StringPtr(env.PortalServerSSHKeySecretName), + }, + } + partialSecretItems := []azkeyvault.SecretItem{ + { + ID: to.StringPtr("test1"), + }, + { + ID: to.StringPtr(env.EncryptionSecretV2Name), + }, + { + ID: to.StringPtr(env.FrontendEncryptionSecretV2Name), + }, + } + instanceID := "testID" + vms := []mgmtcompute.VirtualMachineScaleSetVM{ + { + InstanceID: to.StringPtr(instanceID), + }, + } + healthyVMSS := mgmtcompute.VirtualMachineScaleSetVMInstanceView{ + VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ + Status: &mgmtcompute.InstanceViewStatus{ + Code: to.StringPtr("HealthState/healthy"), + }, + }, + } + + type mock func(*mock_keyvault.MockManager, *mock_compute.MockVirtualMachineScaleSetsClient, *mock_compute.MockVirtualMachineScaleSetVMsClient) + getSecretsFailed := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + k.EXPECT().GetSecrets(ctx).Return( + allSecretItems, errors.New("generic error"), + ) + } + getSecretsSuccessful := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + k.EXPECT().GetSecrets(ctx).Return( + allSecretItems, nil, + ) + } + getNewSecretSuccessful := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + k.EXPECT().GetSecret(ctx, gomock.Any()).Return( + newSecretBundle, nil, + ) + } + getPartialSecretsSuccessful := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + k.EXPECT().GetSecrets(ctx).Return( + partialSecretItems, nil, + ) + } + setSecretSuccessful := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + k.EXPECT().SetSecret(ctx, gomock.Any(), gomock.Any()).Return( + nil, + ) + } + listVMSSFailed := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + vmss.EXPECT().List(ctx, gomock.Any()).Return( + vmsss, errors.New("VM List Failed"), + ) + } + vmssListSuccessful := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + vmss.EXPECT().List(ctx, gomock.Any()).Return( + vmsss, nil, + ) + } + vmssVMsListSuccessful := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + vmssvms.EXPECT().List(ctx, gomock.Any(), gomock.Any(), "", "", "").Return( + vms, nil, + ) + } + restartSuccessful := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + vmssvms.EXPECT().RunCommandAndWait(ctx, gomock.Any(), gomock.Any(), instanceID, gomock.Any()).Return(nil) + } + healthyInstanceView := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + vmssvms.EXPECT().GetInstanceView(gomock.Any(), gomock.Any(), gomock.Any(), instanceID).Return(healthyVMSS, nil) + } + + for _, tt := range []struct { + name string + secretToFind string + mocks []mock + wantErr string + }{ + { + name: "return error if ensureAndRotateSecret fails", + mocks: []mock{ + getSecretsFailed, + }, + wantErr: "generic error", + }, + { + name: "return error if ensureAndRotateSecret passes without rotating any secret but ensureSecret fails", + mocks: []mock{ + getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getSecretsFailed, + }, + wantErr: "generic error", + }, + { + name: "return error if ensureAndRotateSecret passes with rotating a missing secret but ensureSecret fails", + mocks: []mock{ + getPartialSecretsSuccessful, getNewSecretSuccessful, getPartialSecretsSuccessful, getNewSecretSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, getSecretsFailed, + }, + wantErr: "generic error", + }, + { + name: "return error if ensureAndRotateSecret, ensureSecret passes without rotating a secret but ensureSecretKey fails", + mocks: []mock{ + getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getSecretsSuccessful, getSecretsFailed, + }, + wantErr: "generic error", + }, + { + name: "return error if ensureAndRotateSecret, ensureSecret passes with rotating a legacy secret but ensureSecretKey fails", + mocks: []mock{ + getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, getSecretsFailed, + }, + wantErr: "generic error", + }, + { + name: "return nil if ensureAndRotateSecret, ensureSecret, ensureSecretKey passes without rotating a secret", + mocks: []mock{ + getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getSecretsSuccessful, getSecretsSuccessful, + }, + }, + { + name: "return error if ensureAndRotateSecret, ensureSecret, ensureSecretKey passes with rotating secret in each ensure function call but restartoldscaleset failing", + mocks: []mock{ + getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, getSecretsSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, listVMSSFailed, + }, + wantErr: "VM List Failed", + }, + { + name: "return nil if ensureAndRotateSecret, ensureSecret, ensureSecretKey passes with rotating secret in each ensure function call and restartoldscaleset passess successfully", + mocks: []mock{ + getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, getSecretsSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, vmssListSuccessful, vmssVMsListSuccessful, restartSuccessful, healthyInstanceView, vmssListSuccessful, vmssVMsListSuccessful, restartSuccessful, healthyInstanceView, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockKV := mock_keyvault.NewMockManager(controller) + mockVMSS := mock_compute.NewMockVirtualMachineScaleSetsClient(controller) + mockVMSSVM := mock_compute.NewMockVirtualMachineScaleSetVMsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + config: &RPConfig{ + RPResourceGroupName: rgName, + GatewayResourceGroupName: rgName, + }, + serviceKeyvault: mockKV, + portalKeyvault: mockKV, + vmss: mockVMSS, + vmssvms: mockVMSSVM, + } + + for _, m := range tt.mocks { + m(mockKV, mockVMSS, mockVMSSVM) + } + + err := d.configureServiceSecrets(ctx) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestEnsureAndRotateSecret(t *testing.T) { + ctx := context.Background() + secretExists := "secretExists" + noSecretExists := "noSecretExists" + secretItems := []azkeyvault.SecretItem{ + { + ID: to.StringPtr("test1"), + }, + { + ID: &secretExists, + }, + } + nowUnixTime := date.NewUnixTimeFromSeconds(float64(time.Now().Unix())) + oldUnixTime := date.NewUnixTimeFromSeconds(float64(time.Now().Add(-rotateSecretAfter).Unix())) + newSecretBundle := azkeyvault.SecretBundle{ + Attributes: &azkeyvault.SecretAttributes{ + Created: &nowUnixTime, + }, + } + + oldSecretBundle := azkeyvault.SecretBundle{ + Attributes: &azkeyvault.SecretAttributes{ + Created: &oldUnixTime, + }, + } + + type mock func(*mock_keyvault.MockManager) + getSecretsFailed := func(k *mock_keyvault.MockManager) { + k.EXPECT().GetSecrets(ctx).Return( + secretItems, errors.New("generic error"), + ) + } + getSecretsSuccessful := func(k *mock_keyvault.MockManager) { + k.EXPECT().GetSecrets(ctx).Return( + secretItems, nil, + ) + } + getSecretFailed := func(k *mock_keyvault.MockManager) { + k.EXPECT().GetSecret(ctx, secretExists).Return( + newSecretBundle, errors.New("generic error"), + ) + } + getNewSecretSuccessful := func(k *mock_keyvault.MockManager) { + k.EXPECT().GetSecret(ctx, secretExists).Return( + newSecretBundle, nil, + ) + } + getOldSecretSuccessful := func(k *mock_keyvault.MockManager) { + k.EXPECT().GetSecret(ctx, secretExists).Return( + oldSecretBundle, nil, + ) + } + setSecretFails := func(k *mock_keyvault.MockManager) { + k.EXPECT().SetSecret(ctx, gomock.Any(), gomock.Any()).Return( + errors.New("generic error"), + ) + } + setSecretSuccessful := func(k *mock_keyvault.MockManager) { + k.EXPECT().SetSecret(ctx, gomock.Any(), gomock.Any()).Return( + nil, + ) + } + + for _, tt := range []struct { + name string + secretToFind string + mocks []mock + wantErr string + wantBool bool + }{ + { + name: "return false and error if GetSecrets fails", + secretToFind: secretExists, + mocks: []mock{ + getSecretsFailed, + }, + wantBool: false, + wantErr: "generic error", + }, + { + name: "return false and error if GetSecrets passes but GetSecret fails for the found secret", + secretToFind: secretExists, + mocks: []mock{ + getSecretsSuccessful, + getSecretFailed, + }, + wantBool: false, + wantErr: "generic error", + }, + { + name: "return false and nil if GetSecrets and GetSecret passes and the secret is not too old", + secretToFind: secretExists, + mocks: []mock{ + getSecretsSuccessful, + getNewSecretSuccessful, + }, + wantBool: false, + }, + { + name: "return true and error if GetSecrets & GetSecret passes and the secret is old but new secret creation fails", + secretToFind: secretExists, + mocks: []mock{ + getSecretsSuccessful, + getOldSecretSuccessful, + setSecretFails, + }, + wantBool: true, + wantErr: "generic error", + }, + { + name: "return true and nil if GetSecrets & GetSecret passes and the secret is old and new secret creation passes", + secretToFind: secretExists, + mocks: []mock{ + getSecretsSuccessful, + getOldSecretSuccessful, + setSecretSuccessful, + }, + wantBool: true, + }, + { + name: "return true and nil if the secret is not present and new secret creation passes", + secretToFind: noSecretExists, + mocks: []mock{ + getSecretsSuccessful, + setSecretSuccessful, + }, + wantBool: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockKV := mock_keyvault.NewMockManager(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + } + + for _, m := range tt.mocks { + m(mockKV) + } + + got, err := d.ensureAndRotateSecret(ctx, mockKV, tt.secretToFind, 8) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + if tt.wantBool != got { + t.Errorf("%#v", got) + } + }) + } +} + +func TestEnsureSecret(t *testing.T) { + ctx := context.Background() + secretExists := "secretExists" + noSecretExists := "noSecretExists" + secretItems := []azkeyvault.SecretItem{ + { + ID: to.StringPtr("test1"), + }, + { + ID: &secretExists, + }, + } + + type mock func(*mock_keyvault.MockManager) + getSecretsFailed := func(k *mock_keyvault.MockManager) { + k.EXPECT().GetSecrets(ctx).Return( + secretItems, errors.New("generic error"), + ) + } + getSecretsSuccessful := func(k *mock_keyvault.MockManager) { + k.EXPECT().GetSecrets(ctx).Return( + secretItems, nil, + ) + } + setSecretFails := func(k *mock_keyvault.MockManager) { + k.EXPECT().SetSecret(ctx, noSecretExists, gomock.Any()).Return( + errors.New("generic error"), + ) + } + setSecretSuccessful := func(k *mock_keyvault.MockManager) { + k.EXPECT().SetSecret(ctx, noSecretExists, gomock.Any()).Return( + nil, + ) + } + + for _, tt := range []struct { + name string + secretToFind string + mocks []mock + wantErr string + wantBool bool + }{ + { + name: "return false and error if GetSecrets fails", + secretToFind: secretExists, + mocks: []mock{ + getSecretsFailed, + }, + wantBool: false, + wantErr: "generic error", + }, + { + name: "return false and nil if GetSecrets passes and secret is found", + secretToFind: secretExists, + mocks: []mock{ + getSecretsSuccessful, + }, + wantBool: false, + }, + { + name: "return true and error if GetSecrets passes but secret is not found and new secret creation fails", + secretToFind: noSecretExists, + mocks: []mock{ + getSecretsSuccessful, setSecretFails, + }, + wantBool: true, + wantErr: "generic error", + }, + { + name: "return true and nil if GetSecrets passes but secret is not found and new secret creation also passes", + secretToFind: noSecretExists, + mocks: []mock{ + getSecretsSuccessful, setSecretSuccessful, + }, + wantBool: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockKV := mock_keyvault.NewMockManager(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + } + + for _, m := range tt.mocks { + m(mockKV) + } + + got, err := d.ensureSecret(ctx, mockKV, tt.secretToFind, 8) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + if tt.wantBool != got { + t.Errorf("%#v", got) + } + }) + } +} + +func TestCreateSecret(t *testing.T) { + ctx := context.Background() + noSecretExists := "noSecretExists" + + type mock func(*mock_keyvault.MockManager) + setSecretFails := func(k *mock_keyvault.MockManager) { + k.EXPECT().SetSecret(ctx, noSecretExists, gomock.Any()).Return( + errors.New("generic error"), + ) + } + setSecretSuccessful := func(k *mock_keyvault.MockManager) { + k.EXPECT().SetSecret(ctx, noSecretExists, gomock.Any()).Return( + nil, + ) + } + + for _, tt := range []struct { + name string + secretToCreate string + mocks []mock + wantErr string + }{ + { + name: "return error if new secret creation fails", + secretToCreate: noSecretExists, + mocks: []mock{ + setSecretFails, + }, + wantErr: "generic error", + }, + { + name: "return nil new secret creation passes", + secretToCreate: noSecretExists, + mocks: []mock{ + setSecretSuccessful, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockKV := mock_keyvault.NewMockManager(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + } + + for _, m := range tt.mocks { + m(mockKV) + } + + err := d.createSecret(ctx, mockKV, tt.secretToCreate, 8) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestEnsureSecretKey(t *testing.T) { + ctx := context.Background() + secretExists := "secretExists" + noSecretExists := "noSecretExists" + secretItems := []azkeyvault.SecretItem{ + { + ID: to.StringPtr("test1"), + }, + { + ID: &secretExists, + }, + } + + type mock func(*mock_keyvault.MockManager) + getSecretsFailed := func(k *mock_keyvault.MockManager) { + k.EXPECT().GetSecrets(ctx).Return( + secretItems, errors.New("generic error"), + ) + } + getSecretsSuccessful := func(k *mock_keyvault.MockManager) { + k.EXPECT().GetSecrets(ctx).Return( + secretItems, nil, + ) + } + setSecretFails := func(k *mock_keyvault.MockManager) { + k.EXPECT().SetSecret(ctx, noSecretExists, gomock.Any()).Return( + errors.New("generic error"), + ) + } + setSecretSuccessful := func(k *mock_keyvault.MockManager) { + k.EXPECT().SetSecret(ctx, noSecretExists, gomock.Any()).Return( + nil, + ) + } + + for _, tt := range []struct { + name string + secretToFind string + mocks []mock + wantErr string + wantBool bool + }{ + { + name: "return false and error if GetSecrets fails", + secretToFind: secretExists, + mocks: []mock{ + getSecretsFailed, + }, + wantBool: false, + wantErr: "generic error", + }, + { + name: "return false and nil if GetSecrets passes and secret is found", + secretToFind: secretExists, + mocks: []mock{ + getSecretsSuccessful, + }, + wantBool: false, + }, + { + name: "return true and error if GetSecrets passes but secret is not found and new secret creation fails", + secretToFind: noSecretExists, + mocks: []mock{ + getSecretsSuccessful, setSecretFails, + }, + wantBool: true, + wantErr: "generic error", + }, + { + name: "return true and nil if GetSecrets passes but secret is not found and new secret creation also passes", + secretToFind: noSecretExists, + mocks: []mock{ + getSecretsSuccessful, setSecretSuccessful, + }, + wantBool: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockKV := mock_keyvault.NewMockManager(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + } + + for _, m := range tt.mocks { + m(mockKV) + } + + got, err := d.ensureSecretKey(ctx, mockKV, tt.secretToFind) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + if tt.wantBool != got { + t.Errorf("%#v", got) + } + }) + } +} + +func TestRestartOldScalesets(t *testing.T) { + ctx := context.Background() + rgName := "testRG" + rpVMSSName := rpVMSSPrefix + "test" + invalidVMSSName := "other-vmss" + invalidVMSSs := []mgmtcompute.VirtualMachineScaleSet{ + { + Name: to.StringPtr(invalidVMSSName), + }, + } + vmsss := []mgmtcompute.VirtualMachineScaleSet{ + { + Name: to.StringPtr(rpVMSSName), + }, + } + instanceID := "testID" + vms := []mgmtcompute.VirtualMachineScaleSetVM{ + { + InstanceID: to.StringPtr(instanceID), + }, + } + healthyVMSS := mgmtcompute.VirtualMachineScaleSetVMInstanceView{ + VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ + Status: &mgmtcompute.InstanceViewStatus{ + Code: to.StringPtr("HealthState/healthy"), + }, + }, + } + + type mock func(*mock_compute.MockVirtualMachineScaleSetsClient, *mock_compute.MockVirtualMachineScaleSetVMsClient) + listVMSSFailed := func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + vmss.EXPECT().List(ctx, rgName).Return( + vmsss, errors.New("generic error"), + ) + } + invalidVMSSSList := func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + vmss.EXPECT().List(ctx, rgName).Return( + invalidVMSSs, nil, + ) + } + vmssListSuccessful := func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + vmss.EXPECT().List(ctx, rgName).Return( + vmsss, nil, + ) + } + vmssVMsListFailed := func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + vmssvms.EXPECT().List(ctx, rgName, rpVMSSName, "", "", "").Return( + vms, errors.New("generic error"), + ) + } + vmssVMsListSuccessful := func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + vmssvms.EXPECT().List(ctx, rgName, rpVMSSName, "", "", "").Return( + vms, nil, + ) + } + restartSuccessful := func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + vmssvms.EXPECT().RunCommandAndWait(ctx, rgName, rpVMSSName, instanceID, mgmtcompute.RunCommandInput{ + CommandID: to.StringPtr("RunShellScript"), + Script: &[]string{rpRestartScript}, + }).Return(nil) + } + healthyInstanceView := func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { + vmssvms.EXPECT().GetInstanceView(gomock.Any(), rgName, rpVMSSName, instanceID).Return(healthyVMSS, nil) + } + + for _, tt := range []struct { + name string + resourceGroupName string + mocks []mock + wantErr string + }{ + { + name: "Don't continue if vmss list fails", + resourceGroupName: rgName, + mocks: []mock{listVMSSFailed}, + wantErr: "generic error", + }, + { + name: "Don't continue if vmss list has an invalid vmss name", + resourceGroupName: rgName, + mocks: []mock{invalidVMSSSList}, + wantErr: "400: InvalidResource: : provided vmss other-vmss does not match RP or gateway prefix", + }, + { + name: "Don't continue if vmssvms list fails", + resourceGroupName: rgName, + mocks: []mock{vmssListSuccessful, vmssVMsListFailed}, + wantErr: "generic error", + }, + { + name: "Restart is successful for the VMs in VMSS", + resourceGroupName: rgName, + mocks: []mock{vmssListSuccessful, vmssVMsListSuccessful, restartSuccessful, healthyInstanceView}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockVMSS := mock_compute.NewMockVirtualMachineScaleSetsClient(controller) + mockVMSSVM := mock_compute.NewMockVirtualMachineScaleSetVMsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + vmss: mockVMSS, + vmssvms: mockVMSSVM, + } + + for _, m := range tt.mocks { + m(mockVMSS, mockVMSSVM) + } + + err := d.restartOldScalesets(ctx, tt.resourceGroupName) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestRestartOldScaleset(t *testing.T) { + ctx := context.Background() + otherVMSSName := "other-vmss" + rgName := "testRG" + gwyVMSSName := gatewayVMSSPrefix + "test" + rpVMSSName := rpVMSSPrefix + "test" + instanceID := "testID" + vms := []mgmtcompute.VirtualMachineScaleSetVM{ + { + InstanceID: to.StringPtr(instanceID), + }, + } + healthyVMSS := mgmtcompute.VirtualMachineScaleSetVMInstanceView{ + VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ + Status: &mgmtcompute.InstanceViewStatus{ + Code: to.StringPtr("HealthState/healthy"), + }, + }, + } + + type mock func(*mock_compute.MockVirtualMachineScaleSetVMsClient) + listVMSSFailed := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { + c.EXPECT().List(ctx, rgName, gwyVMSSName, "", "", "").Return( + vms, errors.New("generic error"), + ) + } + listVMSSSuccessful := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { + c.EXPECT().List(ctx, rgName, gomock.Any(), "", "", "").Return( + vms, nil, + ) + } + gatewayRestartFailed := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { + c.EXPECT().RunCommandAndWait(ctx, rgName, gwyVMSSName, instanceID, mgmtcompute.RunCommandInput{ + CommandID: to.StringPtr("RunShellScript"), + Script: &[]string{gatewayRestartScript}, + }).Return( + errors.New("generic error"), + ) + } + rpRestartFailed := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { + c.EXPECT().RunCommandAndWait(ctx, rgName, rpVMSSName, instanceID, mgmtcompute.RunCommandInput{ + CommandID: to.StringPtr("RunShellScript"), + Script: &[]string{rpRestartScript}, + }).Return( + errors.New("generic error"), + ) + } + restartSuccessful := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { + c.EXPECT().RunCommandAndWait(ctx, rgName, gomock.Any(), instanceID, gomock.Any()).Return(nil) + } + healthyInstanceView := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { + c.EXPECT().GetInstanceView(gomock.Any(), rgName, gomock.Any(), instanceID).Return(healthyVMSS, nil) + } + for _, tt := range []struct { + name string + vmssName string + resourceGroupName string + mocks []mock + wantErr string + }{ + { + name: "Return an error if the VMSS is not gateway or RP", + vmssName: otherVMSSName, + wantErr: "400: InvalidResource: : provided vmss other-vmss does not match RP or gateway prefix", + }, + { + name: "list VMSS failed", + vmssName: gwyVMSSName, + resourceGroupName: rgName, + mocks: []mock{listVMSSFailed}, + wantErr: "generic error", + }, + { + name: "gateway restart script failed", + vmssName: gwyVMSSName, + resourceGroupName: rgName, + mocks: []mock{listVMSSSuccessful, gatewayRestartFailed}, + wantErr: "generic error", + }, + { + name: "rp restart script failed", + vmssName: rpVMSSName, + resourceGroupName: rgName, + mocks: []mock{listVMSSSuccessful, rpRestartFailed}, + wantErr: "generic error", + }, + { + name: "restart script passes and wait for readiness is successful", + vmssName: rpVMSSName, + resourceGroupName: rgName, + mocks: []mock{listVMSSSuccessful, restartSuccessful, healthyInstanceView}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockVMSS := mock_compute.NewMockVirtualMachineScaleSetVMsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + vmssvms: mockVMSS, + } + + for _, m := range tt.mocks { + m(mockVMSS) + } + + err := d.restartOldScaleset(ctx, tt.vmssName, tt.resourceGroupName) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestWaitForReadiness(t *testing.T) { + ctxTimeout, cancel := context.WithTimeout(context.Background(), 11*time.Second) + vmmssName := "testVMSS" + vmInstanceID := "testVMInstanceID" + testRG := "testRG" + unhealthyVMSS := mgmtcompute.VirtualMachineScaleSetVMInstanceView{ + VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ + Status: &mgmtcompute.InstanceViewStatus{ + Code: to.StringPtr("HealthState/unhealthy"), + }, + }, + } + healthyVMSS := mgmtcompute.VirtualMachineScaleSetVMInstanceView{ + VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ + Status: &mgmtcompute.InstanceViewStatus{ + Code: to.StringPtr("HealthState/healthy"), + }, + }, + } + type mock func(*mock_compute.MockVirtualMachineScaleSetVMsClient) + unhealthyInstanceView := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { + c.EXPECT().GetInstanceView(ctxTimeout, testRG, vmmssName, vmInstanceID).Return(unhealthyVMSS, nil).AnyTimes() + } + healthyInstanceView := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { + c.EXPECT().GetInstanceView(ctxTimeout, testRG, vmmssName, vmInstanceID).Return(healthyVMSS, nil) + } + for _, tt := range []struct { + name string + ctx context.Context + cancel context.CancelFunc + vmssName string + vmInstanceID string + resourceGroupName string + mocks []mock + wantErr string + }{ + { + name: "fail after context times out", + ctx: ctxTimeout, + vmssName: vmmssName, + vmInstanceID: vmInstanceID, + resourceGroupName: testRG, + mocks: []mock{ + unhealthyInstanceView, + }, + wantErr: "timed out waiting for the condition", + }, + { + name: "run successfully after confirming healthy status", + ctx: ctxTimeout, + cancel: cancel, + vmssName: vmmssName, + vmInstanceID: vmInstanceID, + resourceGroupName: testRG, + mocks: []mock{ + healthyInstanceView, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockVMSS := mock_compute.NewMockVirtualMachineScaleSetVMsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + vmssvms: mockVMSS, + } + + for _, m := range tt.mocks { + m(mockVMSS) + } + + defer cancel() + err := d.waitForReadiness(tt.ctx, tt.resourceGroupName, tt.vmssName, tt.vmInstanceID) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestIsVMInstanceHealthy(t *testing.T) { + ctx := context.Background() + vmmssName := "testVMSS" + vmInstanceID := "testVMInstanceID" + rpRGName := "testRPRG" + gatewayRGName := "testGatewayRG" + unhealthyVMSS := mgmtcompute.VirtualMachineScaleSetVMInstanceView{ + VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ + Status: &mgmtcompute.InstanceViewStatus{ + Code: to.StringPtr("HealthState/unhealthy"), + }, + }, + } + healthyVMSS := mgmtcompute.VirtualMachineScaleSetVMInstanceView{ + VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ + Status: &mgmtcompute.InstanceViewStatus{ + Code: to.StringPtr("HealthState/healthy"), + }, + }, + } + + type mock func(*mock_compute.MockVirtualMachineScaleSetVMsClient) + getRPInstanceViewFailed := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { + c.EXPECT().GetInstanceView(ctx, rpRGName, vmmssName, vmInstanceID).Return( + unhealthyVMSS, errors.New("generic error"), + ) + } + getGatewayInstanceViewFailed := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { + c.EXPECT().GetInstanceView(ctx, gatewayRGName, vmmssName, vmInstanceID).Return( + unhealthyVMSS, errors.New("generic error"), + ) + } + unhealthyInstanceView := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { + c.EXPECT().GetInstanceView(ctx, gatewayRGName, vmmssName, vmInstanceID).Return(unhealthyVMSS, nil) + } + healthyInstanceView := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { + c.EXPECT().GetInstanceView(ctx, gatewayRGName, vmmssName, vmInstanceID).Return(healthyVMSS, nil) + } + for _, tt := range []struct { + name string + vmssName string + vmInstanceID string + resourceGroupName string + mocks []mock + wantBool bool + }{ + { + name: "return false if GetInstanceView failed for RP resource group", + vmssName: vmmssName, + vmInstanceID: vmInstanceID, + resourceGroupName: rpRGName, + mocks: []mock{ + getRPInstanceViewFailed, + }, + wantBool: false, + }, + { + name: "return false if GetInstanceView failed for Gateway resource group", + vmssName: vmmssName, + vmInstanceID: vmInstanceID, + resourceGroupName: gatewayRGName, + mocks: []mock{ + getGatewayInstanceViewFailed, + }, + wantBool: false, + }, + { + name: "return false if GetInstanceView return unhealthy VM", + vmssName: vmmssName, + vmInstanceID: vmInstanceID, + resourceGroupName: gatewayRGName, + mocks: []mock{ + unhealthyInstanceView, + }, + wantBool: false, + }, + { + name: "return true if GetInstanceView return healthy VM", + vmssName: vmmssName, + vmInstanceID: vmInstanceID, + resourceGroupName: gatewayRGName, + mocks: []mock{ + healthyInstanceView, + }, + wantBool: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockVMSS := mock_compute.NewMockVirtualMachineScaleSetVMsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + vmssvms: mockVMSS, + } + + for _, m := range tt.mocks { + m(mockVMSS) + } + + got := d.isVMInstanceHealthy(ctx, tt.resourceGroupName, tt.vmssName, tt.vmInstanceID) + if tt.wantBool != got { + t.Errorf("%#v", got) + } + }) + } +} diff --git a/pkg/deploy/upgrade_gateway.go b/pkg/deploy/upgrade_gateway.go index 3af14fb1556..bd7c35f8820 100644 --- a/pkg/deploy/upgrade_gateway.go +++ b/pkg/deploy/upgrade_gateway.go @@ -40,7 +40,7 @@ func (d *deployer) gatewayWaitForReadiness(ctx context.Context, vmssName string) d.log.Printf("waiting for %s instances to be healthy", vmssName) return wait.PollImmediateUntil(10*time.Second, func() (bool, error) { for _, vm := range scalesetVMs { - if !d.isVMInstanceHealthy(ctx, vmssName, *vm.InstanceID) { + if !d.isVMInstanceHealthy(ctx, d.config.GatewayResourceGroupName, vmssName, *vm.InstanceID) { return false, nil } } diff --git a/pkg/deploy/upgrade_rp.go b/pkg/deploy/upgrade_rp.go index 3bf970a432d..d03c474808d 100644 --- a/pkg/deploy/upgrade_rp.go +++ b/pkg/deploy/upgrade_rp.go @@ -40,7 +40,7 @@ func (d *deployer) rpWaitForReadiness(ctx context.Context, vmssName string) erro d.log.Printf("waiting for %s instances to be healthy", vmssName) return wait.PollImmediateUntil(10*time.Second, func() (bool, error) { for _, vm := range scalesetVMs { - if !d.isVMInstanceHealthy(ctx, vmssName, *vm.InstanceID) { + if !d.isVMInstanceHealthy(ctx, d.config.RPResourceGroupName, vmssName, *vm.InstanceID) { return false, nil } } diff --git a/pkg/util/azureclient/mgmt/msi/generate.go b/pkg/util/azureclient/mgmt/msi/generate.go new file mode 100644 index 00000000000..f472fe1efcd --- /dev/null +++ b/pkg/util/azureclient/mgmt/msi/generate.go @@ -0,0 +1,8 @@ +package msi + +// Copyright (c) Microsoft Corporation. +// Licensed under the Apache License 2.0. + +//go:generate rm -rf ../../../../util/mocks/$GOPACKAGE +//go:generate go run ../../../../../vendor/github.com/golang/mock/mockgen -destination=../../../../util/mocks/azureclient/mgmt/$GOPACKAGE/$GOPACKAGE.go github.com/Azure/ARO-RP/pkg/util/azureclient/mgmt/$GOPACKAGE UserAssignedIdentitiesClient +//go:generate go run ../../../../../vendor/golang.org/x/tools/cmd/goimports -local=github.com/Azure/ARO-RP -e -w ../../../../util/mocks/azureclient/mgmt/$GOPACKAGE/$GOPACKAGE.go diff --git a/pkg/util/mocks/azureclient/mgmt/msi/msi.go b/pkg/util/mocks/azureclient/mgmt/msi/msi.go new file mode 100644 index 00000000000..d66bb62005a --- /dev/null +++ b/pkg/util/mocks/azureclient/mgmt/msi/msi.go @@ -0,0 +1,51 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/Azure/ARO-RP/pkg/util/azureclient/mgmt/msi (interfaces: UserAssignedIdentitiesClient) + +// Package mock_msi is a generated GoMock package. +package mock_msi + +import ( + context "context" + reflect "reflect" + + msi "github.com/Azure/azure-sdk-for-go/services/msi/mgmt/2018-11-30/msi" + gomock "github.com/golang/mock/gomock" +) + +// MockUserAssignedIdentitiesClient is a mock of UserAssignedIdentitiesClient interface. +type MockUserAssignedIdentitiesClient struct { + ctrl *gomock.Controller + recorder *MockUserAssignedIdentitiesClientMockRecorder +} + +// MockUserAssignedIdentitiesClientMockRecorder is the mock recorder for MockUserAssignedIdentitiesClient. +type MockUserAssignedIdentitiesClientMockRecorder struct { + mock *MockUserAssignedIdentitiesClient +} + +// NewMockUserAssignedIdentitiesClient creates a new mock instance. +func NewMockUserAssignedIdentitiesClient(ctrl *gomock.Controller) *MockUserAssignedIdentitiesClient { + mock := &MockUserAssignedIdentitiesClient{ctrl: ctrl} + mock.recorder = &MockUserAssignedIdentitiesClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockUserAssignedIdentitiesClient) EXPECT() *MockUserAssignedIdentitiesClientMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockUserAssignedIdentitiesClient) Get(arg0 context.Context, arg1, arg2 string) (msi.Identity, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0, arg1, arg2) + ret0, _ := ret[0].(msi.Identity) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockUserAssignedIdentitiesClientMockRecorder) Get(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockUserAssignedIdentitiesClient)(nil).Get), arg0, arg1, arg2) +} From 5e01c17a6d708d4c6903a510cca17a2824663dbd Mon Sep 17 00:00:00 2001 From: Rajdeep Singh Chauhan Date: Sat, 22 Jul 2023 21:07:20 -0400 Subject: [PATCH 7/8] refactor predeploy.go unit test cases --- pkg/deploy/predeploy_test.go | 1957 ++++++++++++++++------------------ 1 file changed, 924 insertions(+), 1033 deletions(-) diff --git a/pkg/deploy/predeploy_test.go b/pkg/deploy/predeploy_test.go index a02aa3c885e..981dce01488 100644 --- a/pkg/deploy/predeploy_test.go +++ b/pkg/deploy/predeploy_test.go @@ -33,74 +33,33 @@ import ( func TestPreDeploy(t *testing.T) { ctx := context.Background() + location := "testLocation" subscriptionRgName := "testRG-subscription" globalRgName := "testRG-global" rpRgName := "testRG-aro-rp" gatewayRgName := "testRG-gwy" - location := "testLocation" overrideLocation := "overrideTestLocation" group := mgmtfeatures.ResourceGroup{ Location: &location, } fakeMSIObjectId, _ := gofrsuuid.NewV4() msi := mgmtmsi.Identity{ - UserAssignedIdentityProperties: &mgmtmsi.UserAssignedIdentityProperties{ - PrincipalID: &fakeMSIObjectId, - }, + UserAssignedIdentityProperties: &mgmtmsi.UserAssignedIdentityProperties{PrincipalID: &fakeMSIObjectId}, } deployment := mgmtfeatures.DeploymentExtended{} - partialSecretItems := []azkeyvault.SecretItem{ - { - ID: to.StringPtr("test1"), - }, - { - ID: to.StringPtr(env.EncryptionSecretV2Name), - }, - { - ID: to.StringPtr(env.FrontendEncryptionSecretV2Name), - }, - } - rpVMSSName := rpVMSSPrefix + "test" + vmssName := rpVMSSPrefix + "test" nowUnixTime := date.NewUnixTimeFromSeconds(float64(time.Now().Unix())) newSecretBundle := azkeyvault.SecretBundle{ - Attributes: &azkeyvault.SecretAttributes{ - Created: &nowUnixTime, - }, + Attributes: &azkeyvault.SecretAttributes{Created: &nowUnixTime}, } - vmsss := []mgmtcompute.VirtualMachineScaleSet{ - { - Name: to.StringPtr(rpVMSSName), - }, - } - allSecretItems := []azkeyvault.SecretItem{ - { - ID: to.StringPtr("test1"), - }, - { - ID: to.StringPtr(env.EncryptionSecretV2Name), - }, - { - ID: to.StringPtr(env.FrontendEncryptionSecretV2Name), - }, - { - ID: to.StringPtr(env.PortalServerSessionKeySecretName), - }, - { - ID: to.StringPtr(env.EncryptionSecretName), - }, - { - ID: to.StringPtr(env.FrontendEncryptionSecretName), - }, - { - ID: to.StringPtr(env.PortalServerSSHKeySecretName), - }, + vmsss := []mgmtcompute.VirtualMachineScaleSet{{Name: &vmssName}} + oneMissingSecrets := []string{env.FrontendEncryptionSecretV2Name, env.PortalServerSessionKeySecretName, env.EncryptionSecretName, env.FrontendEncryptionSecretName, env.PortalServerSSHKeySecretName} + oneMissingSecretItems := []azkeyvault.SecretItem{} + for _, secret := range oneMissingSecrets { + oneMissingSecretItems = append(oneMissingSecretItems, azkeyvault.SecretItem{ID: to.StringPtr(secret)}) } instanceID := "testID" - vms := []mgmtcompute.VirtualMachineScaleSetVM{ - { - InstanceID: to.StringPtr(instanceID), - }, - } + vms := []mgmtcompute.VirtualMachineScaleSetVM{{InstanceID: to.StringPtr(instanceID)}} healthyVMSS := mgmtcompute.VirtualMachineScaleSetVMInstanceView{ VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ Status: &mgmtcompute.InstanceViewStatus{ @@ -108,345 +67,369 @@ func TestPreDeploy(t *testing.T) { }, }, } - - type mock func(*mock_features.MockDeploymentsClient, *mock_features.MockResourceGroupsClient, *mock_msi.MockUserAssignedIdentitiesClient, *mock_keyvault.MockManager, *mock_compute.MockVirtualMachineScaleSetsClient, *mock_compute.MockVirtualMachineScaleSetVMsClient) - genericSubscriptionDeploymentFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - d.EXPECT().CreateOrUpdateAtSubscriptionScopeAndWait(ctx, gomock.Any(), gomock.Any()).Return( - errors.New("generic error"), - ).AnyTimes() - } - subscriptionDeploymentSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - d.EXPECT().CreateOrUpdateAtSubscriptionScopeAndWait(ctx, gomock.Any(), gomock.Any()).Return(nil) - } - subscriptionRGDeploymentFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - d.EXPECT().CreateOrUpdateAndWait(ctx, subscriptionRgName, gomock.Any(), gomock.Any()).Return( - errors.New("generic error"), - ) - } - globalRGDeploymentFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - d.EXPECT().CreateOrUpdateAndWait(ctx, globalRgName, gomock.Any(), gomock.Any()).Return( - errors.New("generic error"), - ) - } - gatewayRGDeploymentFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - d.EXPECT().CreateOrUpdateAndWait(ctx, gatewayRgName, gomock.Any(), gomock.Any()).Return( - errors.New("generic error"), - ) - } - rpRGDeploymentFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - d.EXPECT().CreateOrUpdateAndWait(ctx, rpRgName, gomock.Any(), gomock.Any()).Return( - errors.New("generic error"), - ) - } - deploymentSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - d.EXPECT().CreateOrUpdateAndWait(ctx, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - } - subscriptionResourceGroupDeploymentFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - rg.EXPECT().CreateOrUpdate(ctx, subscriptionRgName, gomock.Any()).Return( - group, - errors.New("generic error"), - ) - } - globalResourceGroupDeploymentFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - rg.EXPECT().CreateOrUpdate(ctx, globalRgName, gomock.Any()).Return( - group, - errors.New("generic error"), - ) - } - rpResourceGroupDeploymentFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - rg.EXPECT().CreateOrUpdate(ctx, rpRgName, gomock.Any()).Return( - group, - errors.New("generic error"), - ) - } - gatewayResourceGroupDeploymentFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - rg.EXPECT().CreateOrUpdate(ctx, gatewayRgName, gomock.Any()).Return( - group, - errors.New("generic error"), - ) - } - resourceGroupDeploymentSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - rg.EXPECT().CreateOrUpdate(ctx, gomock.Any(), gomock.Any()).Return(group, nil) - } - rpMSIGetFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - m.EXPECT().Get(ctx, rpRgName, gomock.Any()).Return(msi, errors.New("generic error")) - } - rpMSIGetSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - m.EXPECT().Get(ctx, rpRgName, gomock.Any()).Return(msi, nil) - } - gatewayMSIGetFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - m.EXPECT().Get(ctx, gatewayRgName, gomock.Any()).Return(msi, errors.New("generic error")) - } - gatewayMSIGetSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - m.EXPECT().Get(ctx, gatewayRgName, gomock.Any()).Return(msi, nil) - } - getDeploymentFailedWithDeploymentNotFound := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - d.EXPECT().Get(ctx, gatewayRgName, gomock.Any()).Return(deployment, autorest.DetailedError{ - Original: &azure.RequestError{ - ServiceError: &azure.ServiceError{ - Code: "DeploymentNotFound", - Details: []map[string]interface{}{ - {}, - }, + deploymentNotFoundError := autorest.DetailedError{ + Original: &azure.RequestError{ + ServiceError: &azure.ServiceError{ + Code: "DeploymentNotFound", + Details: []map[string]interface{}{ + {}, }, }, - }) - } - getSecretsFailed := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - k.EXPECT().GetSecrets(ctx).Return( - partialSecretItems, errors.New("generic error"), - ) - } - getSecretsSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - k.EXPECT().GetSecrets(ctx).Return( - allSecretItems, nil, - ) - } - getNewSecretSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - k.EXPECT().GetSecret(ctx, gomock.Any()).Return( - newSecretBundle, nil, - ) - } - getPartialSecretsSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - k.EXPECT().GetSecrets(ctx).Return( - partialSecretItems, nil, - ) - } - setSecretSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - k.EXPECT().SetSecret(ctx, gomock.Any(), gomock.Any()).Return( - nil, - ) - } - vmssListSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - vmss.EXPECT().List(ctx, gomock.Any()).Return( - vmsss, nil, - ) - } - vmssVMsListSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - vmssvms.EXPECT().List(ctx, gomock.Any(), gomock.Any(), "", "", "").Return( - vms, nil, - ) - } - restartSuccessful := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - vmssvms.EXPECT().RunCommandAndWait(ctx, gomock.Any(), gomock.Any(), instanceID, gomock.Any()).Return(nil) + }, } - healthyInstanceView := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - vmssvms.EXPECT().GetInstanceView(gomock.Any(), gomock.Any(), gomock.Any(), instanceID).Return(healthyVMSS, nil) + deploymentFailedError := &azure.ServiceError{ + Code: "DeploymentFailed", + Details: []map[string]interface{}{ + {}, + }, } + genericError := errors.New("generic error") - for _, tt := range []struct { - name string - location string - overrideLocation string - acrReplicaDisabled bool + type resourceGroups struct { subscriptionRgName string globalResourceGroup string rpResourceGroupName string gatewayResourceGroupName string - mocks []mock - wantErr string + } + type testParams struct { + resourceGroups resourceGroups + location string + instanceID string + vmssName string + restartScript string + overrideLocation string + acrReplicaDisabled bool + } + type mock func(*mock_features.MockDeploymentsClient, *mock_features.MockResourceGroupsClient, *mock_msi.MockUserAssignedIdentitiesClient, *mock_keyvault.MockManager, *mock_compute.MockVirtualMachineScaleSetsClient, *mock_compute.MockVirtualMachineScaleSetVMsClient, testParams) + createOrUpdateAtSubscriptionScopeAndWaitMock := func(returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + d.EXPECT().CreateOrUpdateAtSubscriptionScopeAndWait(ctx, "rp-global-subscription-"+tp.location, gomock.Any()).Return(returnError) + } + } + createOrUpdateAndWaitMock := func(resourceGroup string, returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + d.EXPECT().CreateOrUpdateAndWait(ctx, resourceGroup, gomock.Any(), gomock.Any()).Return(returnError) + } + } + createOrUpdateMock := func(resourceGroup string, returnResourceGroup mgmtfeatures.ResourceGroup, returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + rg.EXPECT().CreateOrUpdate(ctx, resourceGroup, mgmtfeatures.ResourceGroup{Location: &tp.location}).Return(returnResourceGroup, returnError) + } + } + msiGetMock := func(resourceGroup string, returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + m.EXPECT().Get(ctx, resourceGroup, gomock.Any()).Return(msi, returnError) + } + } + getDeploymentMock := func(resourceGroup string, returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + d.EXPECT().Get(ctx, resourceGroup, gomock.Any()).Return(deployment, returnError) + } + } + getSecretsMock := func(secretItems []azkeyvault.SecretItem, returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + k.EXPECT().GetSecrets(ctx).Return(secretItems, returnError) + } + } + getSecretMock := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + k.EXPECT().GetSecret(ctx, gomock.Any()).Return(newSecretBundle, nil) + } + setSecretMock := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + k.EXPECT().SetSecret(ctx, gomock.Any(), gomock.Any()).Return(nil) + } + vmssListMock := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmss.EXPECT().List(ctx, gomock.Any()).Return(vmsss, nil) + } + vmssVMsListMock := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmssvms.EXPECT().List(ctx, gomock.Any(), tp.vmssName, "", "", "").Return(vms, nil) + } + vmRestartMock := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmssvms.EXPECT().RunCommandAndWait(ctx, gomock.Any(), tp.vmssName, tp.instanceID, mgmtcompute.RunCommandInput{ + CommandID: to.StringPtr("RunShellScript"), + Script: &[]string{tp.restartScript}, + }).Return(nil) + } + instanceViewMock := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmssvms.EXPECT().GetInstanceView(gomock.Any(), gomock.Any(), tp.vmssName, tp.instanceID).Return(healthyVMSS, nil) + } + + for _, tt := range []struct { + name string + acrReplicaDisabled bool + testParams testParams + mocks []mock + wantErr string }{ { - name: "don't continue if Global Subscription RBAC DeploymentFailed", - location: location, + name: "don't continue if Global Subscription RBAC DeploymentFailed", + testParams: testParams{ + location: location, + }, mocks: []mock{ - genericSubscriptionDeploymentFailed, + createOrUpdateAtSubscriptionScopeAndWaitMock(genericError), }, wantErr: "generic error", }, { - name: "don't continue if Global Subscription RBAC Deployment is Successful but SubscriptionResourceGroup creation fails", - location: location, - subscriptionRgName: subscriptionRgName, + name: "don't continue if Global Subscription RBAC Deployment is Successful but SubscriptionResourceGroup creation fails", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRgName: subscriptionRgName, + }, + }, mocks: []mock{ - subscriptionDeploymentSuccessful, subscriptionResourceGroupDeploymentFailed, + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, genericError), }, wantErr: "generic error", }, { - name: "don't continue if Global Subscription RBAC Deployment is Successful but GlobalResourceGroup creation fails", - location: location, - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, + name: "don't continue if SubscriptionResourceGroup creation is Successful but GlobalResourceGroup creation fails", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + }, + }, mocks: []mock{ - subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, globalResourceGroupDeploymentFailed, + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, genericError), }, wantErr: "generic error", }, { - name: "don't continue if Global Subscription RBAC Deployment is Successful but RPResourceGroup creation fails", - location: location, - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, - rpResourceGroupName: rpRgName, + name: "don't continue if GlobalResourceGroup creation is Successful but RPResourceGroup creation fails", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + }, + }, mocks: []mock{ - subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, rpResourceGroupDeploymentFailed, + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, genericError), }, wantErr: "generic error", }, { - name: "don't continue if Global Subscription RBAC Deployment is successful but GatewayResourceGroup creation fails", - location: location, - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, - rpResourceGroupName: rpRgName, - gatewayResourceGroupName: gatewayRgName, + name: "don't continue if RPResourceGroup creation is successful but GatewayResourceGroup creation fails", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + }, mocks: []mock{ - subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, gatewayResourceGroupDeploymentFailed, + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, genericError), }, wantErr: "generic error", }, { - name: "don't continue if Global Subscription RBAC Deployment & resource group creation is successful but rp-subscription template deployment fails", - location: location, - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, - rpResourceGroupName: rpRgName, - gatewayResourceGroupName: gatewayRgName, + name: "don't continue if GatewayResourceGroup is successful but rp-subscription template deployment fails", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + }, mocks: []mock{ - subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, subscriptionRGDeploymentFailed, + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, genericError), }, wantErr: "generic error", }, { - name: "don't continue if Global Subscription RBAC Deployment, resource group creation and rp-subscription template deployment is successful but rp managed identity get fails", - location: location, - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, - rpResourceGroupName: rpRgName, - gatewayResourceGroupName: gatewayRgName, + name: "don't continue if rp-subscription template deployment is successful but rp managed identity get fails", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + }, mocks: []mock{ - subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetFailed, + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, genericError), }, wantErr: "generic error", }, { - name: "don't continue if Global Subscription RBAC Deployment, resource group creation and rp-subscription template deployment is successful but gateway managed identity get fails", - location: location, - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, - rpResourceGroupName: rpRgName, - gatewayResourceGroupName: gatewayRgName, + name: "don't continue if rp managed identity get is successful but gateway managed identity get fails", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + }, mocks: []mock{ - subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetSuccessful, deploymentSuccessful, gatewayMSIGetFailed, + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, genericError), }, wantErr: "generic error", }, { - name: "don't continue if Global Subscription RBAC Deployment, resource group creation and rp-subscription template deployment, msi get is successful but rpglobal deployment get fails", - location: location, - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, - rpResourceGroupName: rpRgName, - gatewayResourceGroupName: gatewayRgName, + name: "don't continue if rpglobal deployment fails", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + }, mocks: []mock{ - subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetSuccessful, deploymentSuccessful, gatewayMSIGetSuccessful, globalRGDeploymentFailed, + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRgName, genericError), }, wantErr: "generic error", }, { - name: "don't continue if Global Subscription RBAC Deployment, resource group creation and rp-subscription template deployment, msi get is successful but rpglobal deployment get fails", - location: location, - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, - rpResourceGroupName: rpRgName, - gatewayResourceGroupName: gatewayRgName, + name: "don't continue if rpglobal deployment fails twice with DeploymentFailed", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + }, mocks: []mock{ - subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetSuccessful, deploymentSuccessful, gatewayMSIGetSuccessful, globalRGDeploymentFailed, + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRgName, deploymentFailedError), createOrUpdateAndWaitMock(globalRgName, deploymentFailedError), }, - wantErr: "generic error", + wantErr: `Code="DeploymentFailed" Message="" Details=[{}]`, }, { - name: "don't continue if Global Subscription RBAC Deployment, resource group creation, rp-subscription deployment, rpglobal deployment is successful but ACR Replication fails", - location: location, - overrideLocation: overrideLocation, - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, - rpResourceGroupName: rpRgName, - gatewayResourceGroupName: gatewayRgName, + name: "don't continue if ACR Replication fails", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + overrideLocation: overrideLocation, + }, mocks: []mock{ - subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetSuccessful, deploymentSuccessful, gatewayMSIGetSuccessful, deploymentSuccessful, globalRGDeploymentFailed, + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRgName, nil), createOrUpdateAndWaitMock(globalRgName, genericError), }, wantErr: "generic error", }, { - name: "don't continue if skipping ACR Replication due to no ACRLocationOverride but failing gateway predeploy", - location: location, - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, - rpResourceGroupName: rpRgName, - gatewayResourceGroupName: gatewayRgName, + name: "don't continue if skipping ACR Replication due to no ACRLocationOverride but failing gateway predeploy", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + }, mocks: []mock{ - subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetSuccessful, deploymentSuccessful, gatewayMSIGetSuccessful, deploymentSuccessful, getDeploymentFailedWithDeploymentNotFound, gatewayRGDeploymentFailed, + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRgName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, genericError), }, wantErr: "generic error", }, { - name: "don't continue if skipping ACR Replication due to same ACRLocationOverride as location but failing gateway predeploy", - location: location, - overrideLocation: location, - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, - rpResourceGroupName: rpRgName, - gatewayResourceGroupName: gatewayRgName, + name: "don't continue if skipping ACR Replication due to ACRLocationOverride same as GlobalResourceGroupLocation but failing gateway predeploy", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + overrideLocation: location, + }, mocks: []mock{ - subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetSuccessful, deploymentSuccessful, gatewayMSIGetSuccessful, deploymentSuccessful, getDeploymentFailedWithDeploymentNotFound, gatewayRGDeploymentFailed, + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRgName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, genericError), }, wantErr: "generic error", }, { - name: "don't continue if skipping ACR Replication due to ACRReplicaDisabled but failing gateway predeploy", - location: location, - overrideLocation: overrideLocation, - acrReplicaDisabled: true, - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, - rpResourceGroupName: rpRgName, - gatewayResourceGroupName: gatewayRgName, + name: "don't continue if skipping ACR Replication due to ACRReplicaDisabled but failing gateway predeploy", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + overrideLocation: overrideLocation, + acrReplicaDisabled: true, + }, mocks: []mock{ - subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetSuccessful, deploymentSuccessful, gatewayMSIGetSuccessful, deploymentSuccessful, getDeploymentFailedWithDeploymentNotFound, gatewayRGDeploymentFailed, + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRgName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, genericError), }, wantErr: "generic error", }, { - name: "don't continue gateway predeploy is successful but rp predeploy failed", - location: location, - overrideLocation: overrideLocation, - acrReplicaDisabled: true, - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, - rpResourceGroupName: rpRgName, - gatewayResourceGroupName: gatewayRgName, + name: "don't continue gateway predeploy is successful but rp predeploy failed", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + overrideLocation: overrideLocation, + acrReplicaDisabled: true, + }, mocks: []mock{ - subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetSuccessful, deploymentSuccessful, gatewayMSIGetSuccessful, deploymentSuccessful, getDeploymentFailedWithDeploymentNotFound, deploymentSuccessful, rpRGDeploymentFailed, + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRgName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, nil), createOrUpdateAndWaitMock(rpRgName, genericError), }, wantErr: "generic error", }, { - name: "get error for the configureServiceSecrets", - location: location, - overrideLocation: overrideLocation, - acrReplicaDisabled: true, - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, - rpResourceGroupName: rpRgName, - gatewayResourceGroupName: gatewayRgName, + name: "get error for the configureServiceSecrets", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + overrideLocation: overrideLocation, + acrReplicaDisabled: true, + }, mocks: []mock{ - subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetSuccessful, deploymentSuccessful, gatewayMSIGetSuccessful, deploymentSuccessful, getDeploymentFailedWithDeploymentNotFound, deploymentSuccessful, deploymentSuccessful, getSecretsFailed, + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRgName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), getSecretsMock(oneMissingSecretItems, genericError), }, wantErr: "generic error", }, { - name: "Everything is successful", - location: location, - overrideLocation: overrideLocation, - acrReplicaDisabled: true, - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, - rpResourceGroupName: rpRgName, - gatewayResourceGroupName: gatewayRgName, + name: "Everything is successful", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRgName: subscriptionRgName, + globalResourceGroup: globalRgName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + overrideLocation: overrideLocation, + acrReplicaDisabled: true, + vmssName: vmssName, + instanceID: instanceID, + restartScript: rpRestartScript, + }, mocks: []mock{ - subscriptionDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, resourceGroupDeploymentSuccessful, deploymentSuccessful, deploymentSuccessful, rpMSIGetSuccessful, deploymentSuccessful, gatewayMSIGetSuccessful, deploymentSuccessful, getDeploymentFailedWithDeploymentNotFound, deploymentSuccessful, deploymentSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, getSecretsSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, vmssListSuccessful, vmssVMsListSuccessful, restartSuccessful, healthyInstanceView, vmssListSuccessful, vmssVMsListSuccessful, restartSuccessful, healthyInstanceView, + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRgName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), getSecretsMock(oneMissingSecretItems, nil), setSecretMock, getSecretsMock(oneMissingSecretItems, nil), getSecretMock, getSecretsMock(oneMissingSecretItems, nil), getSecretMock, getSecretsMock(oneMissingSecretItems, nil), getSecretsMock(oneMissingSecretItems, nil), getSecretsMock(oneMissingSecretItems, nil), vmssListMock, vmssVMsListMock, vmRestartMock, instanceViewMock, vmssListMock, vmssVMsListMock, vmRestartMock, instanceViewMock, }, }, } { @@ -470,16 +453,16 @@ func TestPreDeploy(t *testing.T) { userassignedidentities: mockMSIs, config: &RPConfig{ Configuration: &Configuration{ - GlobalResourceGroupLocation: &tt.location, - SubscriptionResourceGroupLocation: &tt.location, - SubscriptionResourceGroupName: &tt.subscriptionRgName, - GlobalResourceGroupName: &tt.globalResourceGroup, - ACRLocationOverride: &tt.overrideLocation, - ACRReplicaDisabled: &tt.acrReplicaDisabled, + GlobalResourceGroupLocation: &tt.testParams.location, + SubscriptionResourceGroupLocation: &tt.testParams.location, + SubscriptionResourceGroupName: &tt.testParams.resourceGroups.subscriptionRgName, + GlobalResourceGroupName: &tt.testParams.resourceGroups.globalResourceGroup, + ACRLocationOverride: &tt.testParams.overrideLocation, + ACRReplicaDisabled: &tt.testParams.acrReplicaDisabled, }, - RPResourceGroupName: tt.rpResourceGroupName, - GatewayResourceGroupName: tt.gatewayResourceGroupName, - Location: tt.location, + RPResourceGroupName: tt.testParams.resourceGroups.rpResourceGroupName, + GatewayResourceGroupName: tt.testParams.resourceGroups.gatewayResourceGroupName, + Location: tt.testParams.location, }, serviceKeyvault: mockKV, portalKeyvault: mockKV, @@ -488,7 +471,7 @@ func TestPreDeploy(t *testing.T) { } for _, m := range tt.mocks { - m(mockDeployments, mockResourceGroups, mockMSIs, mockKV, mockVMSS, mockVMSSVM) + m(mockDeployments, mockResourceGroups, mockMSIs, mockKV, mockVMSS, mockVMSSVM, tt.testParams) } err := d.PreDeploy(ctx) @@ -500,46 +483,46 @@ func TestPreDeploy(t *testing.T) { func TestDeployRPGlobalSubscription(t *testing.T) { ctx := context.Background() location := "locationTest" - - type mock func(*mock_features.MockDeploymentsClient) - subscriptionDeploymentFailed := func(d *mock_features.MockDeploymentsClient) { - d.EXPECT().CreateOrUpdateAtSubscriptionScopeAndWait(ctx, gomock.Any(), gomock.Any()).Return( - errors.New("generic error"), - ).AnyTimes() + deploymentFailedError := &azure.ServiceError{ + Code: "DeploymentFailed", + Details: []map[string]interface{}{ + {}, + }, } - subscriptionDeploymentFailedWithDeploymentFailed := func(d *mock_features.MockDeploymentsClient) { - d.EXPECT().CreateOrUpdateAtSubscriptionScopeAndWait(ctx, gomock.Any(), gomock.Any()).Return( - &azure.ServiceError{ - Code: "DeploymentFailed", - Details: []map[string]interface{}{ - {}, - }, - }, - ) + genericError := errors.New("generic error") + + type testParams struct { + location string } - subscriptionDeploymentSuccessful := func(d *mock_features.MockDeploymentsClient) { - d.EXPECT().CreateOrUpdateAtSubscriptionScopeAndWait(ctx, gomock.Any(), gomock.Any()).Return(nil) + type mock func(*mock_features.MockDeploymentsClient, testParams) + createOrUpdateAtSubscriptionScopeAndWaitMock := func(returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, tp testParams) { + d.EXPECT().CreateOrUpdateAtSubscriptionScopeAndWait(ctx, "rp-global-subscription-"+tp.location, gomock.Any()).Return(returnError) + } } for _, tt := range []struct { - name string - deploymentFileName string - mocks []mock - wantErr string + name string + testParams testParams + mocks []mock + wantErr string }{ { - name: "Don't continue if deployment fails with error other than DeploymentFailed", - mocks: []mock{subscriptionDeploymentFailed}, - wantErr: "generic error", + name: "Don't continue if deployment fails with error other than DeploymentFailed", + testParams: testParams{location: location}, + mocks: []mock{createOrUpdateAtSubscriptionScopeAndWaitMock(genericError)}, + wantErr: "generic error", }, { - name: "Don't continue if deployment fails with error DeploymentFailed five times", - mocks: []mock{subscriptionDeploymentFailedWithDeploymentFailed, subscriptionDeploymentFailedWithDeploymentFailed, subscriptionDeploymentFailedWithDeploymentFailed, subscriptionDeploymentFailedWithDeploymentFailed, subscriptionDeploymentFailedWithDeploymentFailed}, - wantErr: `Code="DeploymentFailed" Message="" Details=[{}]`, + name: "Don't continue if deployment fails with error DeploymentFailed five times", + testParams: testParams{location: location}, + mocks: []mock{createOrUpdateAtSubscriptionScopeAndWaitMock(deploymentFailedError), createOrUpdateAtSubscriptionScopeAndWaitMock(deploymentFailedError), createOrUpdateAtSubscriptionScopeAndWaitMock(deploymentFailedError), createOrUpdateAtSubscriptionScopeAndWaitMock(deploymentFailedError), createOrUpdateAtSubscriptionScopeAndWaitMock(deploymentFailedError)}, + wantErr: `Code="DeploymentFailed" Message="" Details=[{}]`, }, { - name: "Pass successfully when deployment is successfulin second attempt", - mocks: []mock{subscriptionDeploymentFailedWithDeploymentFailed, subscriptionDeploymentSuccessful}, + name: "Pass successfully when deployment is successfulin second attempt", + testParams: testParams{location: location}, + mocks: []mock{createOrUpdateAtSubscriptionScopeAndWaitMock(deploymentFailedError), createOrUpdateAtSubscriptionScopeAndWaitMock(nil)}, }, } { t.Run(tt.name, func(t *testing.T) { @@ -552,15 +535,15 @@ func TestDeployRPGlobalSubscription(t *testing.T) { log: logrus.NewEntry(logrus.StandardLogger()), config: &RPConfig{ Configuration: &Configuration{ - GlobalResourceGroupLocation: &location, + GlobalResourceGroupLocation: &tt.testParams.location, }, - Location: location, + Location: tt.testParams.location, }, globaldeployments: mockDeployments, } for _, m := range tt.mocks { - m(mockDeployments) + m(mockDeployments, tt.testParams) } err := d.deployRPGlobalSubscription(ctx) @@ -573,33 +556,41 @@ func TestDeployRPSubscription(t *testing.T) { ctx := context.Background() location := "locationTest" subscriptionRGName := "rgTest" + genericError := errors.New("generic error") - type mock func(*mock_features.MockDeploymentsClient) - deploymentFailed := func(d *mock_features.MockDeploymentsClient) { - d.EXPECT().CreateOrUpdateAndWait(ctx, subscriptionRGName, gomock.Any(), gomock.Any()).Return( - errors.New("generic error"), - ) + type testParams struct { + resourceGroup string + location string } - deploymentSuccess := func(d *mock_features.MockDeploymentsClient) { - d.EXPECT().CreateOrUpdateAndWait(ctx, subscriptionRGName, gomock.Any(), gomock.Any()).Return( - nil, - ) + type mock func(*mock_features.MockDeploymentsClient, testParams) + CreateOrUpdateAndWaitMock := func(returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, tp testParams) { + d.EXPECT().CreateOrUpdateAndWait(ctx, tp.resourceGroup, "rp-production-subscription-"+tp.location, gomock.Any()).Return(returnError) + } } for _, tt := range []struct { - name string - deploymentFileName string - mocks []mock - wantErr string + name string + testParams testParams + mocks []mock + wantErr string }{ { - name: "Don't continue if deployment fails", - mocks: []mock{deploymentFailed}, + name: "Don't continue if deployment fails", + testParams: testParams{ + location: location, + resourceGroup: subscriptionRGName, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(genericError)}, wantErr: "generic error", }, { - name: "Pass successfully when deployment is successful", - mocks: []mock{deploymentSuccess}, + name: "Pass successfully when deployment is successful", + testParams: testParams{ + location: location, + resourceGroup: subscriptionRGName, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(nil)}, }, } { t.Run(tt.name, func(t *testing.T) { @@ -612,15 +603,15 @@ func TestDeployRPSubscription(t *testing.T) { log: logrus.NewEntry(logrus.StandardLogger()), config: &RPConfig{ Configuration: &Configuration{ - SubscriptionResourceGroupName: &subscriptionRGName, + SubscriptionResourceGroupName: &tt.testParams.resourceGroup, }, - Location: location, + Location: tt.testParams.location, }, deployments: mockDeployments, } for _, m := range tt.mocks { - m(mockDeployments) + m(mockDeployments, tt.testParams) } err := d.deployRPSubscription(ctx) @@ -635,40 +626,51 @@ func TestDeployManagedIdentity(t *testing.T) { existingFileName := generator.FileGatewayProductionPredeploy deploymentName := strings.TrimSuffix(existingFileName, ".json") notExistingFileName := "testFile" + genericError := errors.New("generic error") - type mock func(*mock_features.MockDeploymentsClient) - deploymentFailed := func(d *mock_features.MockDeploymentsClient) { - d.EXPECT().CreateOrUpdateAndWait(ctx, rgName, deploymentName, gomock.Any()).Return( - errors.New("generic error"), - ) + type testParams struct { + resourceGroup string + deploymentFileName string + deploymentName string } - deploymentSuccess := func(d *mock_features.MockDeploymentsClient) { - d.EXPECT().CreateOrUpdateAndWait(ctx, rgName, deploymentName, gomock.Any()).Return( - nil, - ) + type mock func(*mock_features.MockDeploymentsClient, testParams) + CreateOrUpdateAndWaitMock := func(returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, tp testParams) { + d.EXPECT().CreateOrUpdateAndWait(ctx, tp.resourceGroup, tp.deploymentName, gomock.Any()).Return(returnError) + } } for _, tt := range []struct { - name string - deploymentFileName string - mocks []mock - wantErr string + name string + testParams testParams + mocks []mock + wantErr string }{ { - name: "Don't continue if deployment file does not exist", - deploymentFileName: notExistingFileName, - wantErr: "open " + notExistingFileName + ": file does not exist", + name: "Don't continue if deployment file does not exist", + testParams: testParams{ + deploymentFileName: notExistingFileName, + }, + wantErr: "open " + notExistingFileName + ": file does not exist", }, { - name: "Don't continue if deployment fails", - deploymentFileName: existingFileName, - mocks: []mock{deploymentFailed}, - wantErr: "generic error", + name: "Don't continue if deployment fails", + testParams: testParams{ + deploymentFileName: existingFileName, + deploymentName: deploymentName, + resourceGroup: rgName, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(genericError)}, + wantErr: "generic error", }, { - name: "Pass successfully when deployment is successful", - deploymentFileName: existingFileName, - mocks: []mock{deploymentSuccess}, + name: "Pass successfully when deployment is successful", + testParams: testParams{ + deploymentFileName: existingFileName, + deploymentName: deploymentName, + resourceGroup: rgName, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(nil)}, }, } { t.Run(tt.name, func(t *testing.T) { @@ -686,10 +688,10 @@ func TestDeployManagedIdentity(t *testing.T) { } for _, m := range tt.mocks { - m(mockDeployments) + m(mockDeployments, tt.testParams) } - err := d.deployManagedIdentity(ctx, rgName, tt.deploymentFileName) + err := d.deployManagedIdentity(ctx, tt.testParams.resourceGroup, tt.testParams.deploymentFileName) utilerror.AssertErrorMessage(t, err, tt.wantErr) }) } @@ -701,47 +703,64 @@ func TestDeployRPGlobal(t *testing.T) { globalRGName := "globalRGTest" rpSPID := "rpSPIDTest" gwySPID := "gwySPIDTest" - - type mock func(*mock_features.MockDeploymentsClient) - deploymentFailedWithGenericError := func(d *mock_features.MockDeploymentsClient) { - d.EXPECT().CreateOrUpdateAndWait(ctx, globalRGName, gomock.Any(), gomock.Any()).Return( - errors.New("generic error"), - ) + deploymentFailedError := &azure.ServiceError{ + Code: "DeploymentFailed", + Details: []map[string]interface{}{ + {}, + }, } - deploymentFailed := func(d *mock_features.MockDeploymentsClient) { - d.EXPECT().CreateOrUpdateAndWait(ctx, globalRGName, gomock.Any(), gomock.Any()).Return( - &azure.ServiceError{ - Code: "DeploymentFailed", - Details: []map[string]interface{}{ - {}, - }, - }, - ) + genericError := errors.New("generic error") + + type testParams struct { + resourceGroup string + location string + rpSPID string + gwySPID string } - deploymentSuccess := func(d *mock_features.MockDeploymentsClient) { - d.EXPECT().CreateOrUpdateAndWait(ctx, globalRGName, gomock.Any(), gomock.Any()).Return( - nil, - ) + type mock func(*mock_features.MockDeploymentsClient, testParams) + CreateOrUpdateAndWaitMock := func(returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, tp testParams) { + d.EXPECT().CreateOrUpdateAndWait(ctx, tp.resourceGroup, "rp-global-"+tp.location, gomock.Any()).Return(returnError) + } } for _, tt := range []struct { - name string - mocks []mock - wantErr string + name string + testParams testParams + mocks []mock + wantErr string }{ { - name: "Don't continue if deployment fails with error other than DeploymentFailed", - mocks: []mock{deploymentFailedWithGenericError}, + name: "Don't continue if deployment fails with error other than DeploymentFailed", + testParams: testParams{ + location: location, + resourceGroup: globalRGName, + rpSPID: rpSPID, + gwySPID: gwySPID, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(genericError)}, wantErr: "generic error", }, { - name: "Don't continue if deployment fails with DeploymentFailed error twice", - mocks: []mock{deploymentFailed, deploymentFailed}, + name: "Don't continue if deployment fails with DeploymentFailed error twice", + testParams: testParams{ + location: location, + resourceGroup: globalRGName, + rpSPID: rpSPID, + gwySPID: gwySPID, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(deploymentFailedError), CreateOrUpdateAndWaitMock(deploymentFailedError)}, wantErr: `Code="DeploymentFailed" Message="" Details=[{}]`, }, { - name: "Pass successfully when deployment is successful in second attempt", - mocks: []mock{deploymentFailed, deploymentSuccess}, + name: "Pass successfully when deployment is successful in second attempt", + testParams: testParams{ + location: location, + resourceGroup: globalRGName, + rpSPID: rpSPID, + gwySPID: gwySPID, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(deploymentFailedError), CreateOrUpdateAndWaitMock(nil)}, }, } { t.Run(tt.name, func(t *testing.T) { @@ -754,18 +773,18 @@ func TestDeployRPGlobal(t *testing.T) { log: logrus.NewEntry(logrus.StandardLogger()), config: &RPConfig{ Configuration: &Configuration{ - GlobalResourceGroupName: to.StringPtr(globalRGName), + GlobalResourceGroupName: to.StringPtr(tt.testParams.resourceGroup), }, - Location: location, + Location: tt.testParams.location, }, globaldeployments: mockDeployments, } for _, m := range tt.mocks { - m(mockDeployments) + m(mockDeployments, tt.testParams) } - err := d.deployRPGlobal(ctx, rpSPID, gwySPID) + err := d.deployRPGlobal(ctx, tt.testParams.rpSPID, tt.testParams.gwySPID) utilerror.AssertErrorMessage(t, err, tt.wantErr) }) } @@ -775,32 +794,41 @@ func TestDeployRPGlobalACRReplication(t *testing.T) { ctx := context.Background() globalRGName := "globalRGTest" location := "testLocation" + genericError := errors.New("generic error") - type mock func(*mock_features.MockDeploymentsClient) - deploymentFailed := func(d *mock_features.MockDeploymentsClient) { - d.EXPECT().CreateOrUpdateAndWait(ctx, globalRGName, gomock.Any(), gomock.Any()).Return( - errors.New("generic error"), - ) + type testParams struct { + resourceGroup string + location string } - deploymentSuccess := func(d *mock_features.MockDeploymentsClient) { - d.EXPECT().CreateOrUpdateAndWait(ctx, globalRGName, gomock.Any(), gomock.Any()).Return( - nil, - ) + type mock func(*mock_features.MockDeploymentsClient, testParams) + CreateOrUpdateAndWaitMock := func(returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, tp testParams) { + d.EXPECT().CreateOrUpdateAndWait(ctx, tp.resourceGroup, "rp-global-acr-replication-"+tp.location, gomock.Any()).Return(returnError) + } } for _, tt := range []struct { - name string - mocks []mock - wantErr string + name string + testParams testParams + mocks []mock + wantErr string }{ { - name: "Don't continue if deployment fails", - mocks: []mock{deploymentFailed}, + name: "Don't continue if deployment fails", + testParams: testParams{ + location: location, + resourceGroup: globalRGName, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(genericError)}, wantErr: "generic error", }, { - name: "Pass when deployment is successful", - mocks: []mock{deploymentSuccess}, + name: "Pass when deployment is successful", + testParams: testParams{ + location: location, + resourceGroup: globalRGName, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(nil)}, }, } { t.Run(tt.name, func(t *testing.T) { @@ -813,15 +841,15 @@ func TestDeployRPGlobalACRReplication(t *testing.T) { log: logrus.NewEntry(logrus.StandardLogger()), config: &RPConfig{ Configuration: &Configuration{ - GlobalResourceGroupName: to.StringPtr(globalRGName), + GlobalResourceGroupName: to.StringPtr(tt.testParams.resourceGroup), }, - Location: location, + Location: tt.testParams.location, }, globaldeployments: mockDeployments, } for _, m := range tt.mocks { - m(mockDeployments) + m(mockDeployments, tt.testParams) } err := d.deployRPGlobalACRReplication(ctx) @@ -833,59 +861,66 @@ func TestDeployRPGlobalACRReplication(t *testing.T) { func TestDeployPreDeploy(t *testing.T) { ctx := context.Background() rgName := "testRG" - gwyRGName := "testGwyRG" existingFileName := generator.FileGatewayProductionPredeploy deploymentName := strings.TrimSuffix(existingFileName, ".json") notExistingFileName := "testFile" spIDName := "testSPIDName" spID := "testSPID" + genericError := errors.New("generic error") - type mock func(*mock_features.MockDeploymentsClient) - deploymentFailed := func(d *mock_features.MockDeploymentsClient) { - d.EXPECT().CreateOrUpdateAndWait(ctx, rgName, deploymentName, gomock.Any()).Return( - errors.New("generic error"), - ) - } - deploymentSuccess := func(d *mock_features.MockDeploymentsClient) { - d.EXPECT().CreateOrUpdateAndWait(ctx, rgName, deploymentName, gomock.Any()).Return( - nil, - ) - } - - for _, tt := range []struct { - name string - resourceGroupName string + type testParams struct { + resourceGroup string deploymentFileName string + deploymentName string spIDName string spID string isCreate bool - mocks []mock - wantErr string + } + type mock func(*mock_features.MockDeploymentsClient, testParams) + CreateOrUpdateAndWaitMock := func(returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, tp testParams) { + d.EXPECT().CreateOrUpdateAndWait(ctx, tp.resourceGroup, tp.deploymentName, gomock.Any()).Return(returnError) + } + } + + for _, tt := range []struct { + name string + testParams testParams + mocks []mock + wantErr string }{ { - name: "Don't continue if deployment file does not exist", - resourceGroupName: rgName, - deploymentFileName: notExistingFileName, - spIDName: spIDName, - spID: spID, - wantErr: "open " + notExistingFileName + ": file does not exist", + name: "Don't continue if deployment file does not exist", + testParams: testParams{ + resourceGroup: rgName, + deploymentFileName: notExistingFileName, + spIDName: spIDName, + spID: spID, + }, + wantErr: "open " + notExistingFileName + ": file does not exist", }, { - name: "Don't continue if deployment fails", - resourceGroupName: rgName, - deploymentFileName: existingFileName, - spIDName: spIDName, - spID: spID, - mocks: []mock{deploymentFailed}, - wantErr: "generic error", + name: "Don't continue if deployment fails", + testParams: testParams{ + resourceGroup: rgName, + deploymentFileName: existingFileName, + deploymentName: deploymentName, + spIDName: spIDName, + spID: spID, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(genericError)}, + wantErr: "generic error", }, { - name: "Pass when deployment is successful", - resourceGroupName: rgName, - deploymentFileName: existingFileName, - spIDName: spIDName, - spID: spID, - mocks: []mock{deploymentSuccess}, + name: "Pass when deployment is successful", + testParams: testParams{ + resourceGroup: rgName, + deploymentFileName: existingFileName, + deploymentName: deploymentName, + spIDName: spIDName, + spID: spID, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(nil)}, }, } { t.Run(tt.name, func(t *testing.T) { @@ -898,16 +933,16 @@ func TestDeployPreDeploy(t *testing.T) { log: logrus.NewEntry(logrus.StandardLogger()), config: &RPConfig{ Configuration: &Configuration{}, - GatewayResourceGroupName: gwyRGName, + GatewayResourceGroupName: tt.testParams.resourceGroup, }, deployments: mockDeployments, } for _, m := range tt.mocks { - m(mockDeployments) + m(mockDeployments, tt.testParams) } - err := d.deployPreDeploy(ctx, tt.resourceGroupName, tt.deploymentFileName, tt.spIDName, tt.spID, tt.isCreate) + err := d.deployPreDeploy(ctx, tt.testParams.resourceGroup, tt.testParams.deploymentFileName, tt.testParams.spIDName, tt.testParams.spID, tt.testParams.isCreate) utilerror.AssertErrorMessage(t, err, tt.wantErr) }) } @@ -915,173 +950,131 @@ func TestDeployPreDeploy(t *testing.T) { func TestConfigureServiceSecrets(t *testing.T) { ctx := context.Background() - rpVMSSName := rpVMSSPrefix + "test" + vmssName := rpVMSSPrefix + "test" rgName := "rgTest" nowUnixTime := date.NewUnixTimeFromSeconds(float64(time.Now().Unix())) newSecretBundle := azkeyvault.SecretBundle{ - Attributes: &azkeyvault.SecretAttributes{ - Created: &nowUnixTime, - }, + Attributes: &azkeyvault.SecretAttributes{Created: &nowUnixTime}, } - vmsss := []mgmtcompute.VirtualMachineScaleSet{ - { - Name: to.StringPtr(rpVMSSName), - }, + vmsss := []mgmtcompute.VirtualMachineScaleSet{{Name: to.StringPtr(vmssName)}} + oneMissingSecrets := []string{env.FrontendEncryptionSecretV2Name, env.PortalServerSessionKeySecretName, env.EncryptionSecretName, env.FrontendEncryptionSecretName, env.PortalServerSSHKeySecretName} + oneMissingSecretItems := []azkeyvault.SecretItem{} + for _, secret := range oneMissingSecrets { + oneMissingSecretItems = append(oneMissingSecretItems, azkeyvault.SecretItem{ID: to.StringPtr(secret)}) } - allSecretItems := []azkeyvault.SecretItem{ - { - ID: to.StringPtr("test1"), - }, - { - ID: to.StringPtr(env.EncryptionSecretV2Name), - }, - { - ID: to.StringPtr(env.FrontendEncryptionSecretV2Name), - }, - { - ID: to.StringPtr(env.PortalServerSessionKeySecretName), - }, - { - ID: to.StringPtr(env.EncryptionSecretName), - }, - { - ID: to.StringPtr(env.FrontendEncryptionSecretName), - }, - { - ID: to.StringPtr(env.PortalServerSSHKeySecretName), - }, - } - partialSecretItems := []azkeyvault.SecretItem{ - { - ID: to.StringPtr("test1"), - }, - { - ID: to.StringPtr(env.EncryptionSecretV2Name), - }, - { - ID: to.StringPtr(env.FrontendEncryptionSecretV2Name), - }, + allSecrets := []string{env.EncryptionSecretV2Name, env.FrontendEncryptionSecretV2Name, env.PortalServerSessionKeySecretName, env.EncryptionSecretName, env.FrontendEncryptionSecretName, env.PortalServerSSHKeySecretName} + allSecretItems := []azkeyvault.SecretItem{} + for _, secret := range allSecrets { + allSecretItems = append(allSecretItems, azkeyvault.SecretItem{ID: to.StringPtr(secret)}) } instanceID := "testID" - vms := []mgmtcompute.VirtualMachineScaleSetVM{ - { - InstanceID: to.StringPtr(instanceID), - }, - } + vms := []mgmtcompute.VirtualMachineScaleSetVM{{InstanceID: to.StringPtr(instanceID)}} healthyVMSS := mgmtcompute.VirtualMachineScaleSetVMInstanceView{ VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ - Status: &mgmtcompute.InstanceViewStatus{ - Code: to.StringPtr("HealthState/healthy"), - }, + Status: &mgmtcompute.InstanceViewStatus{Code: to.StringPtr("HealthState/healthy")}, }, } + genericError := errors.New("generic error") - type mock func(*mock_keyvault.MockManager, *mock_compute.MockVirtualMachineScaleSetsClient, *mock_compute.MockVirtualMachineScaleSetVMsClient) - getSecretsFailed := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - k.EXPECT().GetSecrets(ctx).Return( - allSecretItems, errors.New("generic error"), - ) - } - getSecretsSuccessful := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - k.EXPECT().GetSecrets(ctx).Return( - allSecretItems, nil, - ) + type testParams struct { + vmssName string + instanceID string + resourceGroup string + restartScript string } - getNewSecretSuccessful := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - k.EXPECT().GetSecret(ctx, gomock.Any()).Return( - newSecretBundle, nil, - ) + type mock func(*mock_keyvault.MockManager, *mock_compute.MockVirtualMachineScaleSetsClient, *mock_compute.MockVirtualMachineScaleSetVMsClient, testParams) + getSecretsMock := func(secretItems []azkeyvault.SecretItem, returnError error) mock { + return func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + k.EXPECT().GetSecrets(ctx).Return(secretItems, returnError) + } } - getPartialSecretsSuccessful := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - k.EXPECT().GetSecrets(ctx).Return( - partialSecretItems, nil, - ) + getSecretMock := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + k.EXPECT().GetSecret(ctx, gomock.Any()).Return(newSecretBundle, nil) } - setSecretSuccessful := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - k.EXPECT().SetSecret(ctx, gomock.Any(), gomock.Any()).Return( - nil, - ) + setSecretMock := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + k.EXPECT().SetSecret(ctx, gomock.Any(), gomock.Any()).Return(nil) } - listVMSSFailed := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - vmss.EXPECT().List(ctx, gomock.Any()).Return( - vmsss, errors.New("VM List Failed"), - ) + vmssListMock := func(returnError error) mock { + return func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmss.EXPECT().List(ctx, tp.resourceGroup).Return(vmsss, returnError).AnyTimes() + } } - vmssListSuccessful := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - vmss.EXPECT().List(ctx, gomock.Any()).Return( - vmsss, nil, - ) + vmssVMsListMock := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmssvms.EXPECT().List(ctx, tp.resourceGroup, tp.vmssName, "", "", "").Return(vms, nil).AnyTimes() } - vmssVMsListSuccessful := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - vmssvms.EXPECT().List(ctx, gomock.Any(), gomock.Any(), "", "", "").Return( - vms, nil, - ) - } - restartSuccessful := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - vmssvms.EXPECT().RunCommandAndWait(ctx, gomock.Any(), gomock.Any(), instanceID, gomock.Any()).Return(nil) + vmRestartMock := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmssvms.EXPECT().RunCommandAndWait(ctx, tp.resourceGroup, tp.vmssName, tp.instanceID, mgmtcompute.RunCommandInput{ + CommandID: to.StringPtr("RunShellScript"), + Script: &[]string{tp.restartScript}, + }).Return(nil).AnyTimes() } - healthyInstanceView := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - vmssvms.EXPECT().GetInstanceView(gomock.Any(), gomock.Any(), gomock.Any(), instanceID).Return(healthyVMSS, nil) + instanceViewMock := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmssvms.EXPECT().GetInstanceView(gomock.Any(), tp.resourceGroup, tp.vmssName, tp.instanceID).Return(healthyVMSS, nil).AnyTimes() } for _, tt := range []struct { name string secretToFind string + testParams testParams mocks []mock wantErr string }{ { name: "return error if ensureAndRotateSecret fails", mocks: []mock{ - getSecretsFailed, + getSecretsMock(allSecretItems, genericError), }, wantErr: "generic error", }, { name: "return error if ensureAndRotateSecret passes without rotating any secret but ensureSecret fails", mocks: []mock{ - getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getSecretsFailed, + getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, genericError), }, wantErr: "generic error", }, { name: "return error if ensureAndRotateSecret passes with rotating a missing secret but ensureSecret fails", mocks: []mock{ - getPartialSecretsSuccessful, getNewSecretSuccessful, getPartialSecretsSuccessful, getNewSecretSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, getSecretsFailed, + getSecretsMock(oneMissingSecretItems, nil), setSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, genericError), }, wantErr: "generic error", }, { name: "return error if ensureAndRotateSecret, ensureSecret passes without rotating a secret but ensureSecretKey fails", mocks: []mock{ - getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getSecretsSuccessful, getSecretsFailed, - }, - wantErr: "generic error", - }, - { - name: "return error if ensureAndRotateSecret, ensureSecret passes with rotating a legacy secret but ensureSecretKey fails", - mocks: []mock{ - getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, getSecretsFailed, + getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, genericError), }, wantErr: "generic error", }, { name: "return nil if ensureAndRotateSecret, ensureSecret, ensureSecretKey passes without rotating a secret", mocks: []mock{ - getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getSecretsSuccessful, getSecretsSuccessful, + getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, nil), }, }, { name: "return error if ensureAndRotateSecret, ensureSecret, ensureSecretKey passes with rotating secret in each ensure function call but restartoldscaleset failing", + testParams: testParams{ + vmssName: vmssName, + instanceID: instanceID, + resourceGroup: rgName, + }, mocks: []mock{ - getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, getSecretsSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, listVMSSFailed, + getSecretsMock(oneMissingSecretItems, nil), setSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, nil), vmssListMock(genericError), }, - wantErr: "VM List Failed", + wantErr: "generic error", }, { - name: "return nil if ensureAndRotateSecret, ensureSecret, ensureSecretKey passes with rotating secret in each ensure function call and restartoldscaleset passess successfully", + name: "return nil if ensureAndRotateSecret, ensureSecret, ensureSecretKey passes with rotating secret and restartoldscaleset passess successfully", + testParams: testParams{ + vmssName: vmssName, + instanceID: instanceID, + resourceGroup: rgName, + restartScript: rpRestartScript, + }, mocks: []mock{ - getSecretsSuccessful, getNewSecretSuccessful, getSecretsSuccessful, getNewSecretSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, getSecretsSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, getPartialSecretsSuccessful, setSecretSuccessful, vmssListSuccessful, vmssVMsListSuccessful, restartSuccessful, healthyInstanceView, vmssListSuccessful, vmssVMsListSuccessful, restartSuccessful, healthyInstanceView, + getSecretsMock(oneMissingSecretItems, nil), setSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, nil), vmssListMock(nil), vmssVMsListMock, vmRestartMock, instanceViewMock, }, }, } { @@ -1096,8 +1089,8 @@ func TestConfigureServiceSecrets(t *testing.T) { d := deployer{ log: logrus.NewEntry(logrus.StandardLogger()), config: &RPConfig{ - RPResourceGroupName: rgName, - GatewayResourceGroupName: rgName, + RPResourceGroupName: tt.testParams.resourceGroup, + GatewayResourceGroupName: tt.testParams.resourceGroup, }, serviceKeyvault: mockKV, portalKeyvault: mockKV, @@ -1106,7 +1099,7 @@ func TestConfigureServiceSecrets(t *testing.T) { } for _, m := range tt.mocks { - m(mockKV, mockVMSS, mockVMSSVM) + m(mockKV, mockVMSS, mockVMSSVM, tt.testParams) } err := d.configureServiceSecrets(ctx) @@ -1119,129 +1112,82 @@ func TestEnsureAndRotateSecret(t *testing.T) { ctx := context.Background() secretExists := "secretExists" noSecretExists := "noSecretExists" - secretItems := []azkeyvault.SecretItem{ - { - ID: to.StringPtr("test1"), - }, - { - ID: &secretExists, - }, - } + secretItems := []azkeyvault.SecretItem{{ID: &secretExists}} nowUnixTime := date.NewUnixTimeFromSeconds(float64(time.Now().Unix())) oldUnixTime := date.NewUnixTimeFromSeconds(float64(time.Now().Add(-rotateSecretAfter).Unix())) newSecretBundle := azkeyvault.SecretBundle{ - Attributes: &azkeyvault.SecretAttributes{ - Created: &nowUnixTime, - }, + Attributes: &azkeyvault.SecretAttributes{Created: &nowUnixTime}, } - oldSecretBundle := azkeyvault.SecretBundle{ - Attributes: &azkeyvault.SecretAttributes{ - Created: &oldUnixTime, - }, + Attributes: &azkeyvault.SecretAttributes{Created: &oldUnixTime}, } + genericError := errors.New("generic error") - type mock func(*mock_keyvault.MockManager) - getSecretsFailed := func(k *mock_keyvault.MockManager) { - k.EXPECT().GetSecrets(ctx).Return( - secretItems, errors.New("generic error"), - ) - } - getSecretsSuccessful := func(k *mock_keyvault.MockManager) { - k.EXPECT().GetSecrets(ctx).Return( - secretItems, nil, - ) - } - getSecretFailed := func(k *mock_keyvault.MockManager) { - k.EXPECT().GetSecret(ctx, secretExists).Return( - newSecretBundle, errors.New("generic error"), - ) - } - getNewSecretSuccessful := func(k *mock_keyvault.MockManager) { - k.EXPECT().GetSecret(ctx, secretExists).Return( - newSecretBundle, nil, - ) + type testParams struct { + secretToFind string } - getOldSecretSuccessful := func(k *mock_keyvault.MockManager) { - k.EXPECT().GetSecret(ctx, secretExists).Return( - oldSecretBundle, nil, - ) + type mock func(*mock_keyvault.MockManager, testParams) + getSecretsMock := func(returnError error) mock { + return func(k *mock_keyvault.MockManager, tp testParams) { + k.EXPECT().GetSecrets(ctx).Return(secretItems, returnError) + } } - setSecretFails := func(k *mock_keyvault.MockManager) { - k.EXPECT().SetSecret(ctx, gomock.Any(), gomock.Any()).Return( - errors.New("generic error"), - ) + getSecretMock := func(secretBundle azkeyvault.SecretBundle, returnError error) mock { + return func(k *mock_keyvault.MockManager, tp testParams) { + k.EXPECT().GetSecret(ctx, tp.secretToFind).Return(secretBundle, returnError) + } } - setSecretSuccessful := func(k *mock_keyvault.MockManager) { - k.EXPECT().SetSecret(ctx, gomock.Any(), gomock.Any()).Return( - nil, - ) + setSecretMock := func(returnError error) mock { + return func(k *mock_keyvault.MockManager, tp testParams) { + k.EXPECT().SetSecret(ctx, tp.secretToFind, gomock.Any()).Return(returnError) + } } for _, tt := range []struct { - name string - secretToFind string - mocks []mock - wantErr string - wantBool bool + name string + testParams testParams + mocks []mock + wantErr string + wantBool bool }{ { - name: "return false and error if GetSecrets fails", - secretToFind: secretExists, - mocks: []mock{ - getSecretsFailed, - }, - wantBool: false, - wantErr: "generic error", + name: "return false and error if GetSecrets fails", + testParams: testParams{secretToFind: secretExists}, + mocks: []mock{getSecretsMock(genericError)}, + wantBool: false, + wantErr: "generic error", }, { - name: "return false and error if GetSecrets passes but GetSecret fails for the found secret", - secretToFind: secretExists, - mocks: []mock{ - getSecretsSuccessful, - getSecretFailed, - }, - wantBool: false, - wantErr: "generic error", + name: "return false and error if GetSecrets passes but GetSecret fails for the found secret", + testParams: testParams{secretToFind: secretExists}, + mocks: []mock{getSecretsMock(nil), getSecretMock(newSecretBundle, genericError)}, + wantBool: false, + wantErr: "generic error", }, { - name: "return false and nil if GetSecrets and GetSecret passes and the secret is not too old", - secretToFind: secretExists, - mocks: []mock{ - getSecretsSuccessful, - getNewSecretSuccessful, - }, - wantBool: false, + name: "return false and nil if GetSecrets and GetSecret passes and the secret is not too old", + testParams: testParams{secretToFind: secretExists}, + mocks: []mock{getSecretsMock(nil), getSecretMock(newSecretBundle, nil)}, + wantBool: false, }, { - name: "return true and error if GetSecrets & GetSecret passes and the secret is old but new secret creation fails", - secretToFind: secretExists, - mocks: []mock{ - getSecretsSuccessful, - getOldSecretSuccessful, - setSecretFails, - }, - wantBool: true, - wantErr: "generic error", + name: "return true and error if GetSecrets & GetSecret passes and the secret is old but new secret creation fails", + testParams: testParams{secretToFind: secretExists}, + mocks: []mock{getSecretsMock(nil), getSecretMock(oldSecretBundle, nil), setSecretMock(genericError)}, + wantBool: true, + wantErr: "generic error", }, { - name: "return true and nil if GetSecrets & GetSecret passes and the secret is old and new secret creation passes", - secretToFind: secretExists, - mocks: []mock{ - getSecretsSuccessful, - getOldSecretSuccessful, - setSecretSuccessful, - }, - wantBool: true, + name: "return true and nil if GetSecrets & GetSecret passes and the secret is old and new secret creation passes", + testParams: testParams{secretToFind: secretExists}, + mocks: []mock{getSecretsMock(nil), getSecretMock(oldSecretBundle, nil), setSecretMock(nil)}, + wantBool: true, }, { - name: "return true and nil if the secret is not present and new secret creation passes", - secretToFind: noSecretExists, - mocks: []mock{ - getSecretsSuccessful, - setSecretSuccessful, - }, - wantBool: true, + name: "return true and nil if the secret is not present and new secret creation passes", + testParams: testParams{secretToFind: noSecretExists}, + mocks: []mock{getSecretsMock(nil), setSecretMock(nil)}, + wantBool: true, }, } { t.Run(tt.name, func(t *testing.T) { @@ -1255,10 +1201,10 @@ func TestEnsureAndRotateSecret(t *testing.T) { } for _, m := range tt.mocks { - m(mockKV) + m(mockKV, tt.testParams) } - got, err := d.ensureAndRotateSecret(ctx, mockKV, tt.secretToFind, 8) + got, err := d.ensureAndRotateSecret(ctx, mockKV, tt.testParams.secretToFind, 8) utilerror.AssertErrorMessage(t, err, tt.wantErr) if tt.wantBool != got { t.Errorf("%#v", got) @@ -1271,77 +1217,56 @@ func TestEnsureSecret(t *testing.T) { ctx := context.Background() secretExists := "secretExists" noSecretExists := "noSecretExists" - secretItems := []azkeyvault.SecretItem{ - { - ID: to.StringPtr("test1"), - }, - { - ID: &secretExists, - }, - } + secretItems := []azkeyvault.SecretItem{{ID: &secretExists}} + genericError := errors.New("generic error") - type mock func(*mock_keyvault.MockManager) - getSecretsFailed := func(k *mock_keyvault.MockManager) { - k.EXPECT().GetSecrets(ctx).Return( - secretItems, errors.New("generic error"), - ) - } - getSecretsSuccessful := func(k *mock_keyvault.MockManager) { - k.EXPECT().GetSecrets(ctx).Return( - secretItems, nil, - ) + type testParams struct { + secretToFind string } - setSecretFails := func(k *mock_keyvault.MockManager) { - k.EXPECT().SetSecret(ctx, noSecretExists, gomock.Any()).Return( - errors.New("generic error"), - ) + type mock func(*mock_keyvault.MockManager, testParams) + getSecretsMock := func(returnError error) mock { + return func(k *mock_keyvault.MockManager, tp testParams) { + k.EXPECT().GetSecrets(ctx).Return(secretItems, returnError) + } } - setSecretSuccessful := func(k *mock_keyvault.MockManager) { - k.EXPECT().SetSecret(ctx, noSecretExists, gomock.Any()).Return( - nil, - ) + setSecretMock := func(returnError error) mock { + return func(k *mock_keyvault.MockManager, tp testParams) { + k.EXPECT().SetSecret(ctx, tp.secretToFind, gomock.Any()).Return(returnError) + } } for _, tt := range []struct { - name string - secretToFind string - mocks []mock - wantErr string - wantBool bool + name string + testParams testParams + mocks []mock + wantErr string + wantBool bool }{ { - name: "return false and error if GetSecrets fails", - secretToFind: secretExists, - mocks: []mock{ - getSecretsFailed, - }, - wantBool: false, - wantErr: "generic error", + name: "return false and error if GetSecrets fails", + testParams: testParams{secretToFind: secretExists}, + mocks: []mock{getSecretsMock(genericError)}, + wantBool: false, + wantErr: "generic error", }, { - name: "return false and nil if GetSecrets passes and secret is found", - secretToFind: secretExists, - mocks: []mock{ - getSecretsSuccessful, - }, - wantBool: false, + name: "return false and nil if GetSecrets passes and secret is found", + testParams: testParams{secretToFind: secretExists}, + mocks: []mock{getSecretsMock(nil)}, + wantBool: false, }, { - name: "return true and error if GetSecrets passes but secret is not found and new secret creation fails", - secretToFind: noSecretExists, - mocks: []mock{ - getSecretsSuccessful, setSecretFails, - }, - wantBool: true, - wantErr: "generic error", + name: "return true and error if GetSecrets passes but secret is not found and new secret creation fails", + testParams: testParams{secretToFind: noSecretExists}, + mocks: []mock{getSecretsMock(nil), setSecretMock(genericError)}, + wantBool: true, + wantErr: "generic error", }, { - name: "return true and nil if GetSecrets passes but secret is not found and new secret creation also passes", - secretToFind: noSecretExists, - mocks: []mock{ - getSecretsSuccessful, setSecretSuccessful, - }, - wantBool: true, + name: "return true and nil if GetSecrets passes but secret is not found and new secret creation also passes", + testParams: testParams{secretToFind: noSecretExists}, + mocks: []mock{getSecretsMock(nil), setSecretMock(nil)}, + wantBool: true, }, } { t.Run(tt.name, func(t *testing.T) { @@ -1355,10 +1280,10 @@ func TestEnsureSecret(t *testing.T) { } for _, m := range tt.mocks { - m(mockKV) + m(mockKV, tt.testParams) } - got, err := d.ensureSecret(ctx, mockKV, tt.secretToFind, 8) + got, err := d.ensureSecret(ctx, mockKV, tt.testParams.secretToFind, 8) utilerror.AssertErrorMessage(t, err, tt.wantErr) if tt.wantBool != got { t.Errorf("%#v", got) @@ -1370,39 +1295,38 @@ func TestEnsureSecret(t *testing.T) { func TestCreateSecret(t *testing.T) { ctx := context.Background() noSecretExists := "noSecretExists" + genericError := errors.New("generic error") - type mock func(*mock_keyvault.MockManager) - setSecretFails := func(k *mock_keyvault.MockManager) { - k.EXPECT().SetSecret(ctx, noSecretExists, gomock.Any()).Return( - errors.New("generic error"), - ) + type testParams struct { + secretToCreate string } - setSecretSuccessful := func(k *mock_keyvault.MockManager) { - k.EXPECT().SetSecret(ctx, noSecretExists, gomock.Any()).Return( - nil, - ) + type mock func(*mock_keyvault.MockManager, testParams) + setSecretMock := func(returnError error) mock { + return func(k *mock_keyvault.MockManager, tp testParams) { + k.EXPECT().SetSecret(ctx, tp.secretToCreate, gomock.Any()).Return(returnError) + } } for _, tt := range []struct { - name string - secretToCreate string - mocks []mock - wantErr string + name string + testParams testParams + mocks []mock + wantErr string }{ { - name: "return error if new secret creation fails", - secretToCreate: noSecretExists, - mocks: []mock{ - setSecretFails, + name: "return error if new secret creation fails", + testParams: testParams{ + secretToCreate: noSecretExists, }, + mocks: []mock{setSecretMock(genericError)}, wantErr: "generic error", }, { - name: "return nil new secret creation passes", - secretToCreate: noSecretExists, - mocks: []mock{ - setSecretSuccessful, + name: "return nil new secret creation passes", + testParams: testParams{ + secretToCreate: noSecretExists, }, + mocks: []mock{setSecretMock(nil)}, }, } { t.Run(tt.name, func(t *testing.T) { @@ -1416,10 +1340,10 @@ func TestCreateSecret(t *testing.T) { } for _, m := range tt.mocks { - m(mockKV) + m(mockKV, tt.testParams) } - err := d.createSecret(ctx, mockKV, tt.secretToCreate, 8) + err := d.createSecret(ctx, mockKV, tt.testParams.secretToCreate, 8) utilerror.AssertErrorMessage(t, err, tt.wantErr) }) } @@ -1429,77 +1353,56 @@ func TestEnsureSecretKey(t *testing.T) { ctx := context.Background() secretExists := "secretExists" noSecretExists := "noSecretExists" - secretItems := []azkeyvault.SecretItem{ - { - ID: to.StringPtr("test1"), - }, - { - ID: &secretExists, - }, - } + secretItems := []azkeyvault.SecretItem{{ID: &secretExists}} + genericError := errors.New("generic error") - type mock func(*mock_keyvault.MockManager) - getSecretsFailed := func(k *mock_keyvault.MockManager) { - k.EXPECT().GetSecrets(ctx).Return( - secretItems, errors.New("generic error"), - ) - } - getSecretsSuccessful := func(k *mock_keyvault.MockManager) { - k.EXPECT().GetSecrets(ctx).Return( - secretItems, nil, - ) + type testParams struct { + secretToFind string } - setSecretFails := func(k *mock_keyvault.MockManager) { - k.EXPECT().SetSecret(ctx, noSecretExists, gomock.Any()).Return( - errors.New("generic error"), - ) + type mock func(*mock_keyvault.MockManager, testParams) + getSecretsMock := func(returnError error) mock { + return func(k *mock_keyvault.MockManager, tp testParams) { + k.EXPECT().GetSecrets(ctx).Return(secretItems, returnError) + } } - setSecretSuccessful := func(k *mock_keyvault.MockManager) { - k.EXPECT().SetSecret(ctx, noSecretExists, gomock.Any()).Return( - nil, - ) + setSecretMock := func(returnError error) mock { + return func(k *mock_keyvault.MockManager, tp testParams) { + k.EXPECT().SetSecret(ctx, tp.secretToFind, gomock.Any()).Return(returnError) + } } for _, tt := range []struct { - name string - secretToFind string - mocks []mock - wantErr string - wantBool bool + name string + testParams testParams + mocks []mock + wantErr string + wantBool bool }{ { - name: "return false and error if GetSecrets fails", - secretToFind: secretExists, - mocks: []mock{ - getSecretsFailed, - }, - wantBool: false, - wantErr: "generic error", + name: "return false and error if GetSecrets fails", + testParams: testParams{secretToFind: secretExists}, + mocks: []mock{getSecretsMock(genericError)}, + wantBool: false, + wantErr: "generic error", }, { - name: "return false and nil if GetSecrets passes and secret is found", - secretToFind: secretExists, - mocks: []mock{ - getSecretsSuccessful, - }, - wantBool: false, + name: "return false and nil if GetSecrets passes and secret is found", + testParams: testParams{secretToFind: secretExists}, + mocks: []mock{getSecretsMock(nil)}, + wantBool: false, }, { - name: "return true and error if GetSecrets passes but secret is not found and new secret creation fails", - secretToFind: noSecretExists, - mocks: []mock{ - getSecretsSuccessful, setSecretFails, - }, - wantBool: true, - wantErr: "generic error", + name: "return true and error if GetSecrets passes but secret is not found and new secret creation fails", + testParams: testParams{secretToFind: noSecretExists}, + mocks: []mock{getSecretsMock(nil), setSecretMock(genericError)}, + wantBool: true, + wantErr: "generic error", }, { - name: "return true and nil if GetSecrets passes but secret is not found and new secret creation also passes", - secretToFind: noSecretExists, - mocks: []mock{ - getSecretsSuccessful, setSecretSuccessful, - }, - wantBool: true, + name: "return true and nil if GetSecrets passes but secret is not found and new secret creation also passes", + testParams: testParams{secretToFind: noSecretExists}, + mocks: []mock{getSecretsMock(nil), setSecretMock(nil)}, + wantBool: true, }, } { t.Run(tt.name, func(t *testing.T) { @@ -1513,10 +1416,10 @@ func TestEnsureSecretKey(t *testing.T) { } for _, m := range tt.mocks { - m(mockKV) + m(mockKV, tt.testParams) } - got, err := d.ensureSecretKey(ctx, mockKV, tt.secretToFind) + got, err := d.ensureSecretKey(ctx, mockKV, tt.testParams.secretToFind) utilerror.AssertErrorMessage(t, err, tt.wantErr) if tt.wantBool != got { t.Errorf("%#v", got) @@ -1530,22 +1433,10 @@ func TestRestartOldScalesets(t *testing.T) { rgName := "testRG" rpVMSSName := rpVMSSPrefix + "test" invalidVMSSName := "other-vmss" - invalidVMSSs := []mgmtcompute.VirtualMachineScaleSet{ - { - Name: to.StringPtr(invalidVMSSName), - }, - } - vmsss := []mgmtcompute.VirtualMachineScaleSet{ - { - Name: to.StringPtr(rpVMSSName), - }, - } + invalidVMSSs := []mgmtcompute.VirtualMachineScaleSet{{Name: &invalidVMSSName}} + vmsss := []mgmtcompute.VirtualMachineScaleSet{{Name: &rpVMSSName}} instanceID := "testID" - vms := []mgmtcompute.VirtualMachineScaleSetVM{ - { - InstanceID: to.StringPtr(instanceID), - }, - } + vms := []mgmtcompute.VirtualMachineScaleSetVM{{InstanceID: &instanceID}} healthyVMSS := mgmtcompute.VirtualMachineScaleSetVMInstanceView{ VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ Status: &mgmtcompute.InstanceViewStatus{ @@ -1553,71 +1444,71 @@ func TestRestartOldScalesets(t *testing.T) { }, }, } + genericError := errors.New("generic error") - type mock func(*mock_compute.MockVirtualMachineScaleSetsClient, *mock_compute.MockVirtualMachineScaleSetVMsClient) - listVMSSFailed := func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - vmss.EXPECT().List(ctx, rgName).Return( - vmsss, errors.New("generic error"), - ) - } - invalidVMSSSList := func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - vmss.EXPECT().List(ctx, rgName).Return( - invalidVMSSs, nil, - ) + type testParams struct { + resourceGroup string + vmssName string + instanceID string + restartScript string } - vmssListSuccessful := func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - vmss.EXPECT().List(ctx, rgName).Return( - vmsss, nil, - ) + type mock func(*mock_compute.MockVirtualMachineScaleSetsClient, *mock_compute.MockVirtualMachineScaleSetVMsClient, testParams) + listVMSSMock := func(returnVMSS []mgmtcompute.VirtualMachineScaleSet, returnError error) mock { + return func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmss.EXPECT().List(ctx, tp.resourceGroup).Return(returnVMSS, returnError) + } } - vmssVMsListFailed := func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - vmssvms.EXPECT().List(ctx, rgName, rpVMSSName, "", "", "").Return( - vms, errors.New("generic error"), - ) + listVMSSVMMock := func(returnError error) mock { + return func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmssvms.EXPECT().List(ctx, tp.resourceGroup, tp.vmssName, "", "", "").Return(vms, returnError) + } } - vmssVMsListSuccessful := func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - vmssvms.EXPECT().List(ctx, rgName, rpVMSSName, "", "", "").Return( - vms, nil, - ) - } - restartSuccessful := func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - vmssvms.EXPECT().RunCommandAndWait(ctx, rgName, rpVMSSName, instanceID, mgmtcompute.RunCommandInput{ + vmRestartMock := func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmssvms.EXPECT().RunCommandAndWait(ctx, tp.resourceGroup, tp.vmssName, tp.instanceID, mgmtcompute.RunCommandInput{ CommandID: to.StringPtr("RunShellScript"), - Script: &[]string{rpRestartScript}, + Script: &[]string{tp.restartScript}, }).Return(nil) } - healthyInstanceView := func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient) { - vmssvms.EXPECT().GetInstanceView(gomock.Any(), rgName, rpVMSSName, instanceID).Return(healthyVMSS, nil) + getInstanceViewMock := func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmssvms.EXPECT().GetInstanceView(gomock.Any(), tp.resourceGroup, tp.vmssName, tp.instanceID).Return(healthyVMSS, nil) } for _, tt := range []struct { - name string - resourceGroupName string - mocks []mock - wantErr string + name string + mocks []mock + testParams testParams + wantErr string }{ { - name: "Don't continue if vmss list fails", - resourceGroupName: rgName, - mocks: []mock{listVMSSFailed}, - wantErr: "generic error", + name: "Don't continue if vmss list fails", + testParams: testParams{resourceGroup: rgName}, + mocks: []mock{listVMSSMock(vmsss, genericError)}, + wantErr: "generic error", }, { - name: "Don't continue if vmss list has an invalid vmss name", - resourceGroupName: rgName, - mocks: []mock{invalidVMSSSList}, - wantErr: "400: InvalidResource: : provided vmss other-vmss does not match RP or gateway prefix", + name: "Don't continue if vmss list has an invalid vmss name", + testParams: testParams{resourceGroup: rgName}, + mocks: []mock{listVMSSMock(invalidVMSSs, nil)}, + wantErr: "400: InvalidResource: : provided vmss other-vmss does not match RP or gateway prefix", }, { - name: "Don't continue if vmssvms list fails", - resourceGroupName: rgName, - mocks: []mock{vmssListSuccessful, vmssVMsListFailed}, - wantErr: "generic error", + name: "Don't continue if vmssvms list fails", + testParams: testParams{ + resourceGroup: rgName, + vmssName: rpVMSSName, + }, + mocks: []mock{listVMSSMock(vmsss, nil), listVMSSVMMock(genericError)}, + wantErr: "generic error", }, { - name: "Restart is successful for the VMs in VMSS", - resourceGroupName: rgName, - mocks: []mock{vmssListSuccessful, vmssVMsListSuccessful, restartSuccessful, healthyInstanceView}, + name: "Restart is successful for the VMs in VMSS", + testParams: testParams{ + resourceGroup: rgName, + vmssName: rpVMSSName, + instanceID: instanceID, + restartScript: rpRestartScript, + }, + mocks: []mock{listVMSSMock(vmsss, nil), listVMSSVMMock(nil), vmRestartMock, getInstanceViewMock}, }, } { t.Run(tt.name, func(t *testing.T) { @@ -1634,10 +1525,10 @@ func TestRestartOldScalesets(t *testing.T) { } for _, m := range tt.mocks { - m(mockVMSS, mockVMSSVM) + m(mockVMSS, mockVMSSVM, tt.testParams) } - err := d.restartOldScalesets(ctx, tt.resourceGroupName) + err := d.restartOldScalesets(ctx, tt.testParams.resourceGroup) utilerror.AssertErrorMessage(t, err, tt.wantErr) }) } @@ -1646,15 +1537,11 @@ func TestRestartOldScalesets(t *testing.T) { func TestRestartOldScaleset(t *testing.T) { ctx := context.Background() otherVMSSName := "other-vmss" - rgName := "testRG" + rg := "testRG" gwyVMSSName := gatewayVMSSPrefix + "test" rpVMSSName := rpVMSSPrefix + "test" - instanceID := "testID" - vms := []mgmtcompute.VirtualMachineScaleSetVM{ - { - InstanceID: to.StringPtr(instanceID), - }, - } + vmInstanceID := "testID" + vms := []mgmtcompute.VirtualMachineScaleSetVM{{InstanceID: &vmInstanceID}} healthyVMSS := mgmtcompute.VirtualMachineScaleSetVMInstanceView{ VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ Status: &mgmtcompute.InstanceViewStatus{ @@ -1662,78 +1549,84 @@ func TestRestartOldScaleset(t *testing.T) { }, }, } + genericError := errors.New("generic error") - type mock func(*mock_compute.MockVirtualMachineScaleSetVMsClient) - listVMSSFailed := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { - c.EXPECT().List(ctx, rgName, gwyVMSSName, "", "", "").Return( - vms, errors.New("generic error"), - ) + type testParams struct { + resourceGroup string + vmssName string + instanceID string + restartScript string } - listVMSSSuccessful := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { - c.EXPECT().List(ctx, rgName, gomock.Any(), "", "", "").Return( - vms, nil, - ) + type mock func(*mock_compute.MockVirtualMachineScaleSetVMsClient, testParams) + getInstanceViewMock := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + c.EXPECT().GetInstanceView(gomock.Any(), tp.resourceGroup, tp.vmssName, tp.instanceID).Return(healthyVMSS, nil) } - gatewayRestartFailed := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { - c.EXPECT().RunCommandAndWait(ctx, rgName, gwyVMSSName, instanceID, mgmtcompute.RunCommandInput{ - CommandID: to.StringPtr("RunShellScript"), - Script: &[]string{gatewayRestartScript}, - }).Return( - errors.New("generic error"), - ) - } - rpRestartFailed := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { - c.EXPECT().RunCommandAndWait(ctx, rgName, rpVMSSName, instanceID, mgmtcompute.RunCommandInput{ - CommandID: to.StringPtr("RunShellScript"), - Script: &[]string{rpRestartScript}, - }).Return( - errors.New("generic error"), - ) + listVMSSVMMock := func(returnError error) mock { + return func(c *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + c.EXPECT().List(ctx, tp.resourceGroup, tp.vmssName, "", "", "").Return(vms, returnError) + } } - restartSuccessful := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { - c.EXPECT().RunCommandAndWait(ctx, rgName, gomock.Any(), instanceID, gomock.Any()).Return(nil) - } - healthyInstanceView := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { - c.EXPECT().GetInstanceView(gomock.Any(), rgName, gomock.Any(), instanceID).Return(healthyVMSS, nil) + vmRestartMock := func(returnError error) mock { + return func(c *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + c.EXPECT().RunCommandAndWait(ctx, tp.resourceGroup, tp.vmssName, tp.instanceID, mgmtcompute.RunCommandInput{ + CommandID: to.StringPtr("RunShellScript"), + Script: &[]string{tp.restartScript}, + }).Return(returnError) + } } + for _, tt := range []struct { - name string - vmssName string - resourceGroupName string - mocks []mock - wantErr string + name string + testParams testParams + mocks []mock + wantErr string }{ { - name: "Return an error if the VMSS is not gateway or RP", - vmssName: otherVMSSName, - wantErr: "400: InvalidResource: : provided vmss other-vmss does not match RP or gateway prefix", + name: "Return an error if the VMSS is not gateway or RP", + testParams: testParams{vmssName: otherVMSSName}, + wantErr: "400: InvalidResource: : provided vmss other-vmss does not match RP or gateway prefix", }, { - name: "list VMSS failed", - vmssName: gwyVMSSName, - resourceGroupName: rgName, - mocks: []mock{listVMSSFailed}, - wantErr: "generic error", + name: "list VMSS failed", + testParams: testParams{ + resourceGroup: rg, + vmssName: gwyVMSSName, + instanceID: vmInstanceID, + }, + mocks: []mock{listVMSSVMMock(genericError)}, + wantErr: "generic error", }, { - name: "gateway restart script failed", - vmssName: gwyVMSSName, - resourceGroupName: rgName, - mocks: []mock{listVMSSSuccessful, gatewayRestartFailed}, - wantErr: "generic error", + name: "gateway restart script failed", + testParams: testParams{ + resourceGroup: rg, + vmssName: gwyVMSSName, + instanceID: vmInstanceID, + restartScript: gatewayRestartScript, + }, + mocks: []mock{listVMSSVMMock(nil), vmRestartMock(genericError)}, + wantErr: "generic error", }, { - name: "rp restart script failed", - vmssName: rpVMSSName, - resourceGroupName: rgName, - mocks: []mock{listVMSSSuccessful, rpRestartFailed}, - wantErr: "generic error", + name: "rp restart script failed", + testParams: testParams{ + resourceGroup: rg, + vmssName: rpVMSSName, + instanceID: vmInstanceID, + restartScript: rpRestartScript, + }, + mocks: []mock{listVMSSVMMock(nil), vmRestartMock(genericError)}, + wantErr: "generic error", }, { - name: "restart script passes and wait for readiness is successful", - vmssName: rpVMSSName, - resourceGroupName: rgName, - mocks: []mock{listVMSSSuccessful, restartSuccessful, healthyInstanceView}, + name: "restart script passes and wait for readiness is successful", + testParams: testParams{ + resourceGroup: rg, + vmssName: rpVMSSName, + instanceID: vmInstanceID, + restartScript: rpRestartScript, + }, + mocks: []mock{listVMSSVMMock(nil), vmRestartMock(nil), getInstanceViewMock}, }, } { t.Run(tt.name, func(t *testing.T) { @@ -1748,10 +1641,10 @@ func TestRestartOldScaleset(t *testing.T) { } for _, m := range tt.mocks { - m(mockVMSS) + m(mockVMSS, tt.testParams) } - err := d.restartOldScaleset(ctx, tt.vmssName, tt.resourceGroupName) + err := d.restartOldScaleset(ctx, tt.testParams.vmssName, tt.testParams.resourceGroup) utilerror.AssertErrorMessage(t, err, tt.wantErr) }) } @@ -1759,7 +1652,7 @@ func TestRestartOldScaleset(t *testing.T) { func TestWaitForReadiness(t *testing.T) { ctxTimeout, cancel := context.WithTimeout(context.Background(), 11*time.Second) - vmmssName := "testVMSS" + vmssName := "testVMSS" vmInstanceID := "testVMInstanceID" testRG := "testRG" unhealthyVMSS := mgmtcompute.VirtualMachineScaleSetVMInstanceView{ @@ -1776,44 +1669,48 @@ func TestWaitForReadiness(t *testing.T) { }, }, } - type mock func(*mock_compute.MockVirtualMachineScaleSetVMsClient) - unhealthyInstanceView := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { - c.EXPECT().GetInstanceView(ctxTimeout, testRG, vmmssName, vmInstanceID).Return(unhealthyVMSS, nil).AnyTimes() + + type testParams struct { + resourceGroup string + vmssName string + vmInstanceID string + ctx context.Context + cancel context.CancelFunc } - healthyInstanceView := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { - c.EXPECT().GetInstanceView(ctxTimeout, testRG, vmmssName, vmInstanceID).Return(healthyVMSS, nil) + type mock func(*mock_compute.MockVirtualMachineScaleSetVMsClient, testParams) + getInstanceViewMock := func(vm mgmtcompute.VirtualMachineScaleSetVMInstanceView) mock { + return func(c *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + c.EXPECT().GetInstanceView(tp.ctx, tp.resourceGroup, tp.vmssName, tp.vmInstanceID).Return(vm, nil).AnyTimes() + } } + for _, tt := range []struct { - name string - ctx context.Context - cancel context.CancelFunc - vmssName string - vmInstanceID string - resourceGroupName string - mocks []mock - wantErr string + name string + testParams testParams + mocks []mock + wantErr string }{ { - name: "fail after context times out", - ctx: ctxTimeout, - vmssName: vmmssName, - vmInstanceID: vmInstanceID, - resourceGroupName: testRG, - mocks: []mock{ - unhealthyInstanceView, + name: "fail after context times out", + testParams: testParams{ + resourceGroup: testRG, + vmssName: vmssName, + vmInstanceID: vmInstanceID, + ctx: ctxTimeout, }, + mocks: []mock{getInstanceViewMock(unhealthyVMSS)}, wantErr: "timed out waiting for the condition", }, { - name: "run successfully after confirming healthy status", - ctx: ctxTimeout, - cancel: cancel, - vmssName: vmmssName, - vmInstanceID: vmInstanceID, - resourceGroupName: testRG, - mocks: []mock{ - healthyInstanceView, + name: "run successfully after confirming healthy status", + testParams: testParams{ + resourceGroup: testRG, + vmssName: vmssName, + vmInstanceID: vmInstanceID, + ctx: ctxTimeout, + cancel: cancel, }, + mocks: []mock{getInstanceViewMock(healthyVMSS)}, }, } { t.Run(tt.name, func(t *testing.T) { @@ -1828,11 +1725,11 @@ func TestWaitForReadiness(t *testing.T) { } for _, m := range tt.mocks { - m(mockVMSS) + m(mockVMSS, tt.testParams) } defer cancel() - err := d.waitForReadiness(tt.ctx, tt.resourceGroupName, tt.vmssName, tt.vmInstanceID) + err := d.waitForReadiness(tt.testParams.ctx, tt.testParams.resourceGroup, tt.testParams.vmssName, tt.testParams.vmInstanceID) utilerror.AssertErrorMessage(t, err, tt.wantErr) }) } @@ -1840,7 +1737,7 @@ func TestWaitForReadiness(t *testing.T) { func TestIsVMInstanceHealthy(t *testing.T) { ctx := context.Background() - vmmssName := "testVMSS" + vmssName := "testVMSS" vmInstanceID := "testVMInstanceID" rpRGName := "testRPRG" gatewayRGName := "testGatewayRG" @@ -1858,70 +1755,64 @@ func TestIsVMInstanceHealthy(t *testing.T) { }, }, } + genericError := errors.New("generic error") - type mock func(*mock_compute.MockVirtualMachineScaleSetVMsClient) - getRPInstanceViewFailed := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { - c.EXPECT().GetInstanceView(ctx, rpRGName, vmmssName, vmInstanceID).Return( - unhealthyVMSS, errors.New("generic error"), - ) - } - getGatewayInstanceViewFailed := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { - c.EXPECT().GetInstanceView(ctx, gatewayRGName, vmmssName, vmInstanceID).Return( - unhealthyVMSS, errors.New("generic error"), - ) - } - unhealthyInstanceView := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { - c.EXPECT().GetInstanceView(ctx, gatewayRGName, vmmssName, vmInstanceID).Return(unhealthyVMSS, nil) + type testParams struct { + resourceGroup string + vmssName string + instanceID string } - healthyInstanceView := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient) { - c.EXPECT().GetInstanceView(ctx, gatewayRGName, vmmssName, vmInstanceID).Return(healthyVMSS, nil) + type mock func(*mock_compute.MockVirtualMachineScaleSetVMsClient, testParams) + getInstanceViewMock := func(vm mgmtcompute.VirtualMachineScaleSetVMInstanceView, returnError error) mock { + return func(c *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + c.EXPECT().GetInstanceView(ctx, tp.resourceGroup, tp.vmssName, tp.instanceID).Return(vm, returnError).AnyTimes() + } } + for _, tt := range []struct { - name string - vmssName string - vmInstanceID string - resourceGroupName string - mocks []mock - wantBool bool + name string + testParams testParams + mocks []mock + wantBool bool }{ { - name: "return false if GetInstanceView failed for RP resource group", - vmssName: vmmssName, - vmInstanceID: vmInstanceID, - resourceGroupName: rpRGName, - mocks: []mock{ - getRPInstanceViewFailed, + name: "return false if GetInstanceView failed for RP resource group", + testParams: testParams{ + resourceGroup: rpRGName, + vmssName: vmssName, + instanceID: vmInstanceID, }, + mocks: []mock{getInstanceViewMock(healthyVMSS, genericError)}, wantBool: false, }, { - name: "return false if GetInstanceView failed for Gateway resource group", - vmssName: vmmssName, - vmInstanceID: vmInstanceID, - resourceGroupName: gatewayRGName, - mocks: []mock{ - getGatewayInstanceViewFailed, + name: "return false if GetInstanceView failed for Gateway resource group", + testParams: testParams{ + resourceGroup: gatewayRGName, + vmssName: vmssName, + instanceID: vmInstanceID, }, + mocks: []mock{getInstanceViewMock(healthyVMSS, genericError)}, wantBool: false, }, { - name: "return false if GetInstanceView return unhealthy VM", - vmssName: vmmssName, - vmInstanceID: vmInstanceID, - resourceGroupName: gatewayRGName, - mocks: []mock{ - unhealthyInstanceView, + name: "return false if GetInstanceView return unhealthy VM", + testParams: testParams{ + resourceGroup: rpRGName, + vmssName: vmssName, + instanceID: vmInstanceID, }, + mocks: []mock{getInstanceViewMock(unhealthyVMSS, nil)}, wantBool: false, }, { - name: "return true if GetInstanceView return healthy VM", - vmssName: vmmssName, - vmInstanceID: vmInstanceID, - resourceGroupName: gatewayRGName, - mocks: []mock{ - healthyInstanceView, + name: "return true if GetInstanceView return healthy VM", + testParams: testParams{ + resourceGroup: rpRGName, + vmssName: vmssName, + instanceID: vmInstanceID, }, + mocks: []mock{getInstanceViewMock(healthyVMSS, nil)}, wantBool: true, }, } { @@ -1937,10 +1828,10 @@ func TestIsVMInstanceHealthy(t *testing.T) { } for _, m := range tt.mocks { - m(mockVMSS) + m(mockVMSS, tt.testParams) } - got := d.isVMInstanceHealthy(ctx, tt.resourceGroupName, tt.vmssName, tt.vmInstanceID) + got := d.isVMInstanceHealthy(ctx, tt.testParams.resourceGroup, tt.testParams.vmssName, tt.testParams.instanceID) if tt.wantBool != got { t.Errorf("%#v", got) } From ed1657be8495a3cb7291119f91cc92fd64ba4703 Mon Sep 17 00:00:00 2001 From: Rajdeep Singh Chauhan Date: Wed, 26 Jul 2023 11:14:07 -0400 Subject: [PATCH 8/8] remove variables duplication from predeploy test cases --- pkg/deploy/predeploy_test.go | 386 +++++++++++++---------------------- 1 file changed, 143 insertions(+), 243 deletions(-) diff --git a/pkg/deploy/predeploy_test.go b/pkg/deploy/predeploy_test.go index 981dce01488..d004b5f9b94 100644 --- a/pkg/deploy/predeploy_test.go +++ b/pkg/deploy/predeploy_test.go @@ -31,14 +31,62 @@ import ( utilerror "github.com/Azure/ARO-RP/test/util/error" ) +var ( + instanceID = "testID" + rgName = "testRG" + location = "testLocation" + globalRGName = "testRG-global" + subscriptionRGName = "testRG-subscription" + notExistingFileName = "testFile" + existingFileName = generator.FileGatewayProductionPredeploy + existingFileDeploymentName = strings.TrimSuffix(existingFileName, ".json") + secretExists = "secretExists" + noSecretExists = "noSecretExists" + + errGeneric = errors.New("generic error") + deploymentFailedError = &azure.ServiceError{ + Code: "DeploymentFailed", + Details: []map[string]interface{}{{}}, + } + deploymentNotFoundError = autorest.DetailedError{ + Original: &azure.RequestError{ + ServiceError: &azure.ServiceError{ + Code: "DeploymentNotFound", + Details: []map[string]interface{}{{}}, + }, + }, + } + + healthyVMSS = mgmtcompute.VirtualMachineScaleSetVMInstanceView{ + VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ + Status: &mgmtcompute.InstanceViewStatus{Code: to.StringPtr("HealthState/healthy")}, + }, + } + unhealthyVMSS = mgmtcompute.VirtualMachineScaleSetVMInstanceView{ + VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ + Status: &mgmtcompute.InstanceViewStatus{ + Code: to.StringPtr("HealthState/unhealthy"), + }, + }, + } + + nowUnixTime = date.NewUnixTimeFromSeconds(float64(time.Now().Unix())) + newSecretBundle = azkeyvault.SecretBundle{ + Attributes: &azkeyvault.SecretAttributes{Created: &nowUnixTime}, + } + + secretItems = []azkeyvault.SecretItem{{ID: to.StringPtr("secretExists")}} + + vms = []mgmtcompute.VirtualMachineScaleSetVM{{InstanceID: to.StringPtr(instanceID)}} +) + func TestPreDeploy(t *testing.T) { ctx := context.Background() - location := "testLocation" - subscriptionRgName := "testRG-subscription" - globalRgName := "testRG-global" rpRgName := "testRG-aro-rp" gatewayRgName := "testRG-gwy" overrideLocation := "overrideTestLocation" + vmssName := rpVMSSPrefix + "test" + group := mgmtfeatures.ResourceGroup{ Location: &location, } @@ -47,46 +95,15 @@ func TestPreDeploy(t *testing.T) { UserAssignedIdentityProperties: &mgmtmsi.UserAssignedIdentityProperties{PrincipalID: &fakeMSIObjectId}, } deployment := mgmtfeatures.DeploymentExtended{} - vmssName := rpVMSSPrefix + "test" - nowUnixTime := date.NewUnixTimeFromSeconds(float64(time.Now().Unix())) - newSecretBundle := azkeyvault.SecretBundle{ - Attributes: &azkeyvault.SecretAttributes{Created: &nowUnixTime}, - } vmsss := []mgmtcompute.VirtualMachineScaleSet{{Name: &vmssName}} oneMissingSecrets := []string{env.FrontendEncryptionSecretV2Name, env.PortalServerSessionKeySecretName, env.EncryptionSecretName, env.FrontendEncryptionSecretName, env.PortalServerSSHKeySecretName} oneMissingSecretItems := []azkeyvault.SecretItem{} for _, secret := range oneMissingSecrets { oneMissingSecretItems = append(oneMissingSecretItems, azkeyvault.SecretItem{ID: to.StringPtr(secret)}) } - instanceID := "testID" - vms := []mgmtcompute.VirtualMachineScaleSetVM{{InstanceID: to.StringPtr(instanceID)}} - healthyVMSS := mgmtcompute.VirtualMachineScaleSetVMInstanceView{ - VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ - Status: &mgmtcompute.InstanceViewStatus{ - Code: to.StringPtr("HealthState/healthy"), - }, - }, - } - deploymentNotFoundError := autorest.DetailedError{ - Original: &azure.RequestError{ - ServiceError: &azure.ServiceError{ - Code: "DeploymentNotFound", - Details: []map[string]interface{}{ - {}, - }, - }, - }, - } - deploymentFailedError := &azure.ServiceError{ - Code: "DeploymentFailed", - Details: []map[string]interface{}{ - {}, - }, - } - genericError := errors.New("generic error") type resourceGroups struct { - subscriptionRgName string + subscriptionRGName string globalResourceGroup string rpResourceGroupName string gatewayResourceGroupName string @@ -166,7 +183,7 @@ func TestPreDeploy(t *testing.T) { location: location, }, mocks: []mock{ - createOrUpdateAtSubscriptionScopeAndWaitMock(genericError), + createOrUpdateAtSubscriptionScopeAndWaitMock(errGeneric), }, wantErr: "generic error", }, @@ -175,11 +192,11 @@ func TestPreDeploy(t *testing.T) { testParams: testParams{ location: location, resourceGroups: resourceGroups{ - subscriptionRgName: subscriptionRgName, + subscriptionRGName: subscriptionRGName, }, }, mocks: []mock{ - createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, genericError), + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, errGeneric), }, wantErr: "generic error", }, @@ -188,12 +205,12 @@ func TestPreDeploy(t *testing.T) { testParams: testParams{ location: location, resourceGroups: resourceGroups{ - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, }, }, mocks: []mock{ - createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, genericError), + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, errGeneric), }, wantErr: "generic error", }, @@ -202,13 +219,13 @@ func TestPreDeploy(t *testing.T) { testParams: testParams{ location: location, resourceGroups: resourceGroups{ - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, rpResourceGroupName: rpRgName, }, }, mocks: []mock{ - createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, genericError), + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, errGeneric), }, wantErr: "generic error", }, @@ -217,14 +234,14 @@ func TestPreDeploy(t *testing.T) { testParams: testParams{ location: location, resourceGroups: resourceGroups{ - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, rpResourceGroupName: rpRgName, gatewayResourceGroupName: gatewayRgName, }, }, mocks: []mock{ - createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, genericError), + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, errGeneric), }, wantErr: "generic error", }, @@ -233,14 +250,14 @@ func TestPreDeploy(t *testing.T) { testParams: testParams{ location: location, resourceGroups: resourceGroups{ - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, rpResourceGroupName: rpRgName, gatewayResourceGroupName: gatewayRgName, }, }, mocks: []mock{ - createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, genericError), + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, errGeneric), }, wantErr: "generic error", }, @@ -249,14 +266,14 @@ func TestPreDeploy(t *testing.T) { testParams: testParams{ location: location, resourceGroups: resourceGroups{ - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, rpResourceGroupName: rpRgName, gatewayResourceGroupName: gatewayRgName, }, }, mocks: []mock{ - createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, genericError), + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, errGeneric), }, wantErr: "generic error", }, @@ -265,14 +282,14 @@ func TestPreDeploy(t *testing.T) { testParams: testParams{ location: location, resourceGroups: resourceGroups{ - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, rpResourceGroupName: rpRgName, gatewayResourceGroupName: gatewayRgName, }, }, mocks: []mock{ - createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, genericError), + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, errGeneric), }, wantErr: "generic error", }, @@ -281,14 +298,14 @@ func TestPreDeploy(t *testing.T) { testParams: testParams{ location: location, resourceGroups: resourceGroups{ - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, rpResourceGroupName: rpRgName, gatewayResourceGroupName: gatewayRgName, }, }, mocks: []mock{ - createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRgName, genericError), + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRGName, errGeneric), }, wantErr: "generic error", }, @@ -297,14 +314,14 @@ func TestPreDeploy(t *testing.T) { testParams: testParams{ location: location, resourceGroups: resourceGroups{ - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, rpResourceGroupName: rpRgName, gatewayResourceGroupName: gatewayRgName, }, }, mocks: []mock{ - createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRgName, deploymentFailedError), createOrUpdateAndWaitMock(globalRgName, deploymentFailedError), + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRGName, deploymentFailedError), createOrUpdateAndWaitMock(globalRGName, deploymentFailedError), }, wantErr: `Code="DeploymentFailed" Message="" Details=[{}]`, }, @@ -313,15 +330,15 @@ func TestPreDeploy(t *testing.T) { testParams: testParams{ location: location, resourceGroups: resourceGroups{ - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, rpResourceGroupName: rpRgName, gatewayResourceGroupName: gatewayRgName, }, overrideLocation: overrideLocation, }, mocks: []mock{ - createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRgName, nil), createOrUpdateAndWaitMock(globalRgName, genericError), + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRGName, nil), createOrUpdateAndWaitMock(globalRGName, errGeneric), }, wantErr: "generic error", }, @@ -330,14 +347,14 @@ func TestPreDeploy(t *testing.T) { testParams: testParams{ location: location, resourceGroups: resourceGroups{ - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, rpResourceGroupName: rpRgName, gatewayResourceGroupName: gatewayRgName, }, }, mocks: []mock{ - createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRgName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, genericError), + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRGName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, errGeneric), }, wantErr: "generic error", }, @@ -346,15 +363,15 @@ func TestPreDeploy(t *testing.T) { testParams: testParams{ location: location, resourceGroups: resourceGroups{ - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, rpResourceGroupName: rpRgName, gatewayResourceGroupName: gatewayRgName, }, overrideLocation: location, }, mocks: []mock{ - createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRgName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, genericError), + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRGName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, errGeneric), }, wantErr: "generic error", }, @@ -363,8 +380,8 @@ func TestPreDeploy(t *testing.T) { testParams: testParams{ location: location, resourceGroups: resourceGroups{ - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, rpResourceGroupName: rpRgName, gatewayResourceGroupName: gatewayRgName, }, @@ -372,7 +389,7 @@ func TestPreDeploy(t *testing.T) { acrReplicaDisabled: true, }, mocks: []mock{ - createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRgName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, genericError), + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRGName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, errGeneric), }, wantErr: "generic error", }, @@ -381,8 +398,8 @@ func TestPreDeploy(t *testing.T) { testParams: testParams{ location: location, resourceGroups: resourceGroups{ - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, rpResourceGroupName: rpRgName, gatewayResourceGroupName: gatewayRgName, }, @@ -390,7 +407,7 @@ func TestPreDeploy(t *testing.T) { acrReplicaDisabled: true, }, mocks: []mock{ - createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRgName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, nil), createOrUpdateAndWaitMock(rpRgName, genericError), + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRGName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, nil), createOrUpdateAndWaitMock(rpRgName, errGeneric), }, wantErr: "generic error", }, @@ -399,8 +416,8 @@ func TestPreDeploy(t *testing.T) { testParams: testParams{ location: location, resourceGroups: resourceGroups{ - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, rpResourceGroupName: rpRgName, gatewayResourceGroupName: gatewayRgName, }, @@ -408,7 +425,7 @@ func TestPreDeploy(t *testing.T) { acrReplicaDisabled: true, }, mocks: []mock{ - createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRgName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), getSecretsMock(oneMissingSecretItems, genericError), + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRGName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), getSecretsMock(oneMissingSecretItems, errGeneric), }, wantErr: "generic error", }, @@ -417,8 +434,8 @@ func TestPreDeploy(t *testing.T) { testParams: testParams{ location: location, resourceGroups: resourceGroups{ - subscriptionRgName: subscriptionRgName, - globalResourceGroup: globalRgName, + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, rpResourceGroupName: rpRgName, gatewayResourceGroupName: gatewayRgName, }, @@ -429,7 +446,7 @@ func TestPreDeploy(t *testing.T) { restartScript: rpRestartScript, }, mocks: []mock{ - createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRgName, group, nil), createOrUpdateMock(globalRgName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRgName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), getSecretsMock(oneMissingSecretItems, nil), setSecretMock, getSecretsMock(oneMissingSecretItems, nil), getSecretMock, getSecretsMock(oneMissingSecretItems, nil), getSecretMock, getSecretsMock(oneMissingSecretItems, nil), getSecretsMock(oneMissingSecretItems, nil), getSecretsMock(oneMissingSecretItems, nil), vmssListMock, vmssVMsListMock, vmRestartMock, instanceViewMock, vmssListMock, vmssVMsListMock, vmRestartMock, instanceViewMock, + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRGName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), getSecretsMock(oneMissingSecretItems, nil), setSecretMock, getSecretsMock(oneMissingSecretItems, nil), getSecretMock, getSecretsMock(oneMissingSecretItems, nil), getSecretMock, getSecretsMock(oneMissingSecretItems, nil), getSecretsMock(oneMissingSecretItems, nil), getSecretsMock(oneMissingSecretItems, nil), vmssListMock, vmssVMsListMock, vmRestartMock, instanceViewMock, vmssListMock, vmssVMsListMock, vmRestartMock, instanceViewMock, }, }, } { @@ -455,7 +472,7 @@ func TestPreDeploy(t *testing.T) { Configuration: &Configuration{ GlobalResourceGroupLocation: &tt.testParams.location, SubscriptionResourceGroupLocation: &tt.testParams.location, - SubscriptionResourceGroupName: &tt.testParams.resourceGroups.subscriptionRgName, + SubscriptionResourceGroupName: &tt.testParams.resourceGroups.subscriptionRGName, GlobalResourceGroupName: &tt.testParams.resourceGroups.globalResourceGroup, ACRLocationOverride: &tt.testParams.overrideLocation, ACRReplicaDisabled: &tt.testParams.acrReplicaDisabled, @@ -482,14 +499,6 @@ func TestPreDeploy(t *testing.T) { func TestDeployRPGlobalSubscription(t *testing.T) { ctx := context.Background() - location := "locationTest" - deploymentFailedError := &azure.ServiceError{ - Code: "DeploymentFailed", - Details: []map[string]interface{}{ - {}, - }, - } - genericError := errors.New("generic error") type testParams struct { location string @@ -510,7 +519,7 @@ func TestDeployRPGlobalSubscription(t *testing.T) { { name: "Don't continue if deployment fails with error other than DeploymentFailed", testParams: testParams{location: location}, - mocks: []mock{createOrUpdateAtSubscriptionScopeAndWaitMock(genericError)}, + mocks: []mock{createOrUpdateAtSubscriptionScopeAndWaitMock(errGeneric)}, wantErr: "generic error", }, { @@ -554,9 +563,6 @@ func TestDeployRPGlobalSubscription(t *testing.T) { func TestDeployRPSubscription(t *testing.T) { ctx := context.Background() - location := "locationTest" - subscriptionRGName := "rgTest" - genericError := errors.New("generic error") type testParams struct { resourceGroup string @@ -581,7 +587,7 @@ func TestDeployRPSubscription(t *testing.T) { location: location, resourceGroup: subscriptionRGName, }, - mocks: []mock{CreateOrUpdateAndWaitMock(genericError)}, + mocks: []mock{CreateOrUpdateAndWaitMock(errGeneric)}, wantErr: "generic error", }, { @@ -622,11 +628,6 @@ func TestDeployRPSubscription(t *testing.T) { func TestDeployManagedIdentity(t *testing.T) { ctx := context.Background() - rgName := "rgTest" - existingFileName := generator.FileGatewayProductionPredeploy - deploymentName := strings.TrimSuffix(existingFileName, ".json") - notExistingFileName := "testFile" - genericError := errors.New("generic error") type testParams struct { resourceGroup string @@ -657,17 +658,17 @@ func TestDeployManagedIdentity(t *testing.T) { name: "Don't continue if deployment fails", testParams: testParams{ deploymentFileName: existingFileName, - deploymentName: deploymentName, + deploymentName: existingFileDeploymentName, resourceGroup: rgName, }, - mocks: []mock{CreateOrUpdateAndWaitMock(genericError)}, + mocks: []mock{CreateOrUpdateAndWaitMock(errGeneric)}, wantErr: "generic error", }, { name: "Pass successfully when deployment is successful", testParams: testParams{ deploymentFileName: existingFileName, - deploymentName: deploymentName, + deploymentName: existingFileDeploymentName, resourceGroup: rgName, }, mocks: []mock{CreateOrUpdateAndWaitMock(nil)}, @@ -699,17 +700,8 @@ func TestDeployManagedIdentity(t *testing.T) { func TestDeployRPGlobal(t *testing.T) { ctx := context.Background() - location := "locationTest" - globalRGName := "globalRGTest" rpSPID := "rpSPIDTest" gwySPID := "gwySPIDTest" - deploymentFailedError := &azure.ServiceError{ - Code: "DeploymentFailed", - Details: []map[string]interface{}{ - {}, - }, - } - genericError := errors.New("generic error") type testParams struct { resourceGroup string @@ -738,7 +730,7 @@ func TestDeployRPGlobal(t *testing.T) { rpSPID: rpSPID, gwySPID: gwySPID, }, - mocks: []mock{CreateOrUpdateAndWaitMock(genericError)}, + mocks: []mock{CreateOrUpdateAndWaitMock(errGeneric)}, wantErr: "generic error", }, { @@ -792,9 +784,6 @@ func TestDeployRPGlobal(t *testing.T) { func TestDeployRPGlobalACRReplication(t *testing.T) { ctx := context.Background() - globalRGName := "globalRGTest" - location := "testLocation" - genericError := errors.New("generic error") type testParams struct { resourceGroup string @@ -819,7 +808,7 @@ func TestDeployRPGlobalACRReplication(t *testing.T) { location: location, resourceGroup: globalRGName, }, - mocks: []mock{CreateOrUpdateAndWaitMock(genericError)}, + mocks: []mock{CreateOrUpdateAndWaitMock(errGeneric)}, wantErr: "generic error", }, { @@ -860,13 +849,8 @@ func TestDeployRPGlobalACRReplication(t *testing.T) { func TestDeployPreDeploy(t *testing.T) { ctx := context.Background() - rgName := "testRG" - existingFileName := generator.FileGatewayProductionPredeploy - deploymentName := strings.TrimSuffix(existingFileName, ".json") - notExistingFileName := "testFile" spIDName := "testSPIDName" spID := "testSPID" - genericError := errors.New("generic error") type testParams struct { resourceGroup string @@ -904,11 +888,11 @@ func TestDeployPreDeploy(t *testing.T) { testParams: testParams{ resourceGroup: rgName, deploymentFileName: existingFileName, - deploymentName: deploymentName, + deploymentName: existingFileDeploymentName, spIDName: spIDName, spID: spID, }, - mocks: []mock{CreateOrUpdateAndWaitMock(genericError)}, + mocks: []mock{CreateOrUpdateAndWaitMock(errGeneric)}, wantErr: "generic error", }, { @@ -916,7 +900,7 @@ func TestDeployPreDeploy(t *testing.T) { testParams: testParams{ resourceGroup: rgName, deploymentFileName: existingFileName, - deploymentName: deploymentName, + deploymentName: existingFileDeploymentName, spIDName: spIDName, spID: spID, }, @@ -951,11 +935,6 @@ func TestDeployPreDeploy(t *testing.T) { func TestConfigureServiceSecrets(t *testing.T) { ctx := context.Background() vmssName := rpVMSSPrefix + "test" - rgName := "rgTest" - nowUnixTime := date.NewUnixTimeFromSeconds(float64(time.Now().Unix())) - newSecretBundle := azkeyvault.SecretBundle{ - Attributes: &azkeyvault.SecretAttributes{Created: &nowUnixTime}, - } vmsss := []mgmtcompute.VirtualMachineScaleSet{{Name: to.StringPtr(vmssName)}} oneMissingSecrets := []string{env.FrontendEncryptionSecretV2Name, env.PortalServerSessionKeySecretName, env.EncryptionSecretName, env.FrontendEncryptionSecretName, env.PortalServerSSHKeySecretName} oneMissingSecretItems := []azkeyvault.SecretItem{} @@ -967,14 +946,6 @@ func TestConfigureServiceSecrets(t *testing.T) { for _, secret := range allSecrets { allSecretItems = append(allSecretItems, azkeyvault.SecretItem{ID: to.StringPtr(secret)}) } - instanceID := "testID" - vms := []mgmtcompute.VirtualMachineScaleSetVM{{InstanceID: to.StringPtr(instanceID)}} - healthyVMSS := mgmtcompute.VirtualMachineScaleSetVMInstanceView{ - VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ - Status: &mgmtcompute.InstanceViewStatus{Code: to.StringPtr("HealthState/healthy")}, - }, - } - genericError := errors.New("generic error") type testParams struct { vmssName string @@ -1022,28 +993,28 @@ func TestConfigureServiceSecrets(t *testing.T) { { name: "return error if ensureAndRotateSecret fails", mocks: []mock{ - getSecretsMock(allSecretItems, genericError), + getSecretsMock(allSecretItems, errGeneric), }, wantErr: "generic error", }, { name: "return error if ensureAndRotateSecret passes without rotating any secret but ensureSecret fails", mocks: []mock{ - getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, genericError), + getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, errGeneric), }, wantErr: "generic error", }, { name: "return error if ensureAndRotateSecret passes with rotating a missing secret but ensureSecret fails", mocks: []mock{ - getSecretsMock(oneMissingSecretItems, nil), setSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, genericError), + getSecretsMock(oneMissingSecretItems, nil), setSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, errGeneric), }, wantErr: "generic error", }, { name: "return error if ensureAndRotateSecret, ensureSecret passes without rotating a secret but ensureSecretKey fails", mocks: []mock{ - getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, genericError), + getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, errGeneric), }, wantErr: "generic error", }, @@ -1061,7 +1032,7 @@ func TestConfigureServiceSecrets(t *testing.T) { resourceGroup: rgName, }, mocks: []mock{ - getSecretsMock(oneMissingSecretItems, nil), setSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, nil), vmssListMock(genericError), + getSecretsMock(oneMissingSecretItems, nil), setSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, nil), vmssListMock(errGeneric), }, wantErr: "generic error", }, @@ -1110,18 +1081,10 @@ func TestConfigureServiceSecrets(t *testing.T) { func TestEnsureAndRotateSecret(t *testing.T) { ctx := context.Background() - secretExists := "secretExists" - noSecretExists := "noSecretExists" - secretItems := []azkeyvault.SecretItem{{ID: &secretExists}} - nowUnixTime := date.NewUnixTimeFromSeconds(float64(time.Now().Unix())) oldUnixTime := date.NewUnixTimeFromSeconds(float64(time.Now().Add(-rotateSecretAfter).Unix())) - newSecretBundle := azkeyvault.SecretBundle{ - Attributes: &azkeyvault.SecretAttributes{Created: &nowUnixTime}, - } oldSecretBundle := azkeyvault.SecretBundle{ Attributes: &azkeyvault.SecretAttributes{Created: &oldUnixTime}, } - genericError := errors.New("generic error") type testParams struct { secretToFind string @@ -1153,14 +1116,14 @@ func TestEnsureAndRotateSecret(t *testing.T) { { name: "return false and error if GetSecrets fails", testParams: testParams{secretToFind: secretExists}, - mocks: []mock{getSecretsMock(genericError)}, + mocks: []mock{getSecretsMock(errGeneric)}, wantBool: false, wantErr: "generic error", }, { name: "return false and error if GetSecrets passes but GetSecret fails for the found secret", testParams: testParams{secretToFind: secretExists}, - mocks: []mock{getSecretsMock(nil), getSecretMock(newSecretBundle, genericError)}, + mocks: []mock{getSecretsMock(nil), getSecretMock(newSecretBundle, errGeneric)}, wantBool: false, wantErr: "generic error", }, @@ -1173,7 +1136,7 @@ func TestEnsureAndRotateSecret(t *testing.T) { { name: "return true and error if GetSecrets & GetSecret passes and the secret is old but new secret creation fails", testParams: testParams{secretToFind: secretExists}, - mocks: []mock{getSecretsMock(nil), getSecretMock(oldSecretBundle, nil), setSecretMock(genericError)}, + mocks: []mock{getSecretsMock(nil), getSecretMock(oldSecretBundle, nil), setSecretMock(errGeneric)}, wantBool: true, wantErr: "generic error", }, @@ -1215,10 +1178,6 @@ func TestEnsureAndRotateSecret(t *testing.T) { func TestEnsureSecret(t *testing.T) { ctx := context.Background() - secretExists := "secretExists" - noSecretExists := "noSecretExists" - secretItems := []azkeyvault.SecretItem{{ID: &secretExists}} - genericError := errors.New("generic error") type testParams struct { secretToFind string @@ -1245,7 +1204,7 @@ func TestEnsureSecret(t *testing.T) { { name: "return false and error if GetSecrets fails", testParams: testParams{secretToFind: secretExists}, - mocks: []mock{getSecretsMock(genericError)}, + mocks: []mock{getSecretsMock(errGeneric)}, wantBool: false, wantErr: "generic error", }, @@ -1258,7 +1217,7 @@ func TestEnsureSecret(t *testing.T) { { name: "return true and error if GetSecrets passes but secret is not found and new secret creation fails", testParams: testParams{secretToFind: noSecretExists}, - mocks: []mock{getSecretsMock(nil), setSecretMock(genericError)}, + mocks: []mock{getSecretsMock(nil), setSecretMock(errGeneric)}, wantBool: true, wantErr: "generic error", }, @@ -1294,8 +1253,6 @@ func TestEnsureSecret(t *testing.T) { func TestCreateSecret(t *testing.T) { ctx := context.Background() - noSecretExists := "noSecretExists" - genericError := errors.New("generic error") type testParams struct { secretToCreate string @@ -1318,7 +1275,7 @@ func TestCreateSecret(t *testing.T) { testParams: testParams{ secretToCreate: noSecretExists, }, - mocks: []mock{setSecretMock(genericError)}, + mocks: []mock{setSecretMock(errGeneric)}, wantErr: "generic error", }, { @@ -1351,10 +1308,6 @@ func TestCreateSecret(t *testing.T) { func TestEnsureSecretKey(t *testing.T) { ctx := context.Background() - secretExists := "secretExists" - noSecretExists := "noSecretExists" - secretItems := []azkeyvault.SecretItem{{ID: &secretExists}} - genericError := errors.New("generic error") type testParams struct { secretToFind string @@ -1381,7 +1334,7 @@ func TestEnsureSecretKey(t *testing.T) { { name: "return false and error if GetSecrets fails", testParams: testParams{secretToFind: secretExists}, - mocks: []mock{getSecretsMock(genericError)}, + mocks: []mock{getSecretsMock(errGeneric)}, wantBool: false, wantErr: "generic error", }, @@ -1394,7 +1347,7 @@ func TestEnsureSecretKey(t *testing.T) { { name: "return true and error if GetSecrets passes but secret is not found and new secret creation fails", testParams: testParams{secretToFind: noSecretExists}, - mocks: []mock{getSecretsMock(nil), setSecretMock(genericError)}, + mocks: []mock{getSecretsMock(nil), setSecretMock(errGeneric)}, wantBool: true, wantErr: "generic error", }, @@ -1430,21 +1383,10 @@ func TestEnsureSecretKey(t *testing.T) { func TestRestartOldScalesets(t *testing.T) { ctx := context.Background() - rgName := "testRG" rpVMSSName := rpVMSSPrefix + "test" invalidVMSSName := "other-vmss" invalidVMSSs := []mgmtcompute.VirtualMachineScaleSet{{Name: &invalidVMSSName}} vmsss := []mgmtcompute.VirtualMachineScaleSet{{Name: &rpVMSSName}} - instanceID := "testID" - vms := []mgmtcompute.VirtualMachineScaleSetVM{{InstanceID: &instanceID}} - healthyVMSS := mgmtcompute.VirtualMachineScaleSetVMInstanceView{ - VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ - Status: &mgmtcompute.InstanceViewStatus{ - Code: to.StringPtr("HealthState/healthy"), - }, - }, - } - genericError := errors.New("generic error") type testParams struct { resourceGroup string @@ -1482,7 +1424,7 @@ func TestRestartOldScalesets(t *testing.T) { { name: "Don't continue if vmss list fails", testParams: testParams{resourceGroup: rgName}, - mocks: []mock{listVMSSMock(vmsss, genericError)}, + mocks: []mock{listVMSSMock(vmsss, errGeneric)}, wantErr: "generic error", }, { @@ -1497,7 +1439,7 @@ func TestRestartOldScalesets(t *testing.T) { resourceGroup: rgName, vmssName: rpVMSSName, }, - mocks: []mock{listVMSSMock(vmsss, nil), listVMSSVMMock(genericError)}, + mocks: []mock{listVMSSMock(vmsss, nil), listVMSSVMMock(errGeneric)}, wantErr: "generic error", }, { @@ -1537,19 +1479,8 @@ func TestRestartOldScalesets(t *testing.T) { func TestRestartOldScaleset(t *testing.T) { ctx := context.Background() otherVMSSName := "other-vmss" - rg := "testRG" gwyVMSSName := gatewayVMSSPrefix + "test" rpVMSSName := rpVMSSPrefix + "test" - vmInstanceID := "testID" - vms := []mgmtcompute.VirtualMachineScaleSetVM{{InstanceID: &vmInstanceID}} - healthyVMSS := mgmtcompute.VirtualMachineScaleSetVMInstanceView{ - VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ - Status: &mgmtcompute.InstanceViewStatus{ - Code: to.StringPtr("HealthState/healthy"), - }, - }, - } - genericError := errors.New("generic error") type testParams struct { resourceGroup string @@ -1589,41 +1520,41 @@ func TestRestartOldScaleset(t *testing.T) { { name: "list VMSS failed", testParams: testParams{ - resourceGroup: rg, + resourceGroup: rgName, vmssName: gwyVMSSName, - instanceID: vmInstanceID, + instanceID: instanceID, }, - mocks: []mock{listVMSSVMMock(genericError)}, + mocks: []mock{listVMSSVMMock(errGeneric)}, wantErr: "generic error", }, { name: "gateway restart script failed", testParams: testParams{ - resourceGroup: rg, + resourceGroup: rgName, vmssName: gwyVMSSName, - instanceID: vmInstanceID, + instanceID: instanceID, restartScript: gatewayRestartScript, }, - mocks: []mock{listVMSSVMMock(nil), vmRestartMock(genericError)}, + mocks: []mock{listVMSSVMMock(nil), vmRestartMock(errGeneric)}, wantErr: "generic error", }, { name: "rp restart script failed", testParams: testParams{ - resourceGroup: rg, + resourceGroup: rgName, vmssName: rpVMSSName, - instanceID: vmInstanceID, + instanceID: instanceID, restartScript: rpRestartScript, }, - mocks: []mock{listVMSSVMMock(nil), vmRestartMock(genericError)}, + mocks: []mock{listVMSSVMMock(nil), vmRestartMock(errGeneric)}, wantErr: "generic error", }, { name: "restart script passes and wait for readiness is successful", testParams: testParams{ - resourceGroup: rg, + resourceGroup: rgName, vmssName: rpVMSSName, - instanceID: vmInstanceID, + instanceID: instanceID, restartScript: rpRestartScript, }, mocks: []mock{listVMSSVMMock(nil), vmRestartMock(nil), getInstanceViewMock}, @@ -1653,22 +1584,6 @@ func TestRestartOldScaleset(t *testing.T) { func TestWaitForReadiness(t *testing.T) { ctxTimeout, cancel := context.WithTimeout(context.Background(), 11*time.Second) vmssName := "testVMSS" - vmInstanceID := "testVMInstanceID" - testRG := "testRG" - unhealthyVMSS := mgmtcompute.VirtualMachineScaleSetVMInstanceView{ - VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ - Status: &mgmtcompute.InstanceViewStatus{ - Code: to.StringPtr("HealthState/unhealthy"), - }, - }, - } - healthyVMSS := mgmtcompute.VirtualMachineScaleSetVMInstanceView{ - VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ - Status: &mgmtcompute.InstanceViewStatus{ - Code: to.StringPtr("HealthState/healthy"), - }, - }, - } type testParams struct { resourceGroup string @@ -1693,9 +1608,9 @@ func TestWaitForReadiness(t *testing.T) { { name: "fail after context times out", testParams: testParams{ - resourceGroup: testRG, + resourceGroup: rgName, vmssName: vmssName, - vmInstanceID: vmInstanceID, + vmInstanceID: instanceID, ctx: ctxTimeout, }, mocks: []mock{getInstanceViewMock(unhealthyVMSS)}, @@ -1704,9 +1619,9 @@ func TestWaitForReadiness(t *testing.T) { { name: "run successfully after confirming healthy status", testParams: testParams{ - resourceGroup: testRG, + resourceGroup: rgName, vmssName: vmssName, - vmInstanceID: vmInstanceID, + vmInstanceID: instanceID, ctx: ctxTimeout, cancel: cancel, }, @@ -1741,21 +1656,6 @@ func TestIsVMInstanceHealthy(t *testing.T) { vmInstanceID := "testVMInstanceID" rpRGName := "testRPRG" gatewayRGName := "testGatewayRG" - unhealthyVMSS := mgmtcompute.VirtualMachineScaleSetVMInstanceView{ - VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ - Status: &mgmtcompute.InstanceViewStatus{ - Code: to.StringPtr("HealthState/unhealthy"), - }, - }, - } - healthyVMSS := mgmtcompute.VirtualMachineScaleSetVMInstanceView{ - VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ - Status: &mgmtcompute.InstanceViewStatus{ - Code: to.StringPtr("HealthState/healthy"), - }, - }, - } - genericError := errors.New("generic error") type testParams struct { resourceGroup string @@ -1782,7 +1682,7 @@ func TestIsVMInstanceHealthy(t *testing.T) { vmssName: vmssName, instanceID: vmInstanceID, }, - mocks: []mock{getInstanceViewMock(healthyVMSS, genericError)}, + mocks: []mock{getInstanceViewMock(healthyVMSS, errGeneric)}, wantBool: false, }, { @@ -1792,7 +1692,7 @@ func TestIsVMInstanceHealthy(t *testing.T) { vmssName: vmssName, instanceID: vmInstanceID, }, - mocks: []mock{getInstanceViewMock(healthyVMSS, genericError)}, + mocks: []mock{getInstanceViewMock(healthyVMSS, errGeneric)}, wantBool: false, }, {