From b24b0771685e60035c8e6252aadacc92feb042af Mon Sep 17 00:00:00 2001 From: Rajdeep Singh Chauhan Date: Wed, 7 Jun 2023 09:29:06 -0400 Subject: [PATCH] update predeploy to restart old VMSS when service secrets rotated --- pkg/deploy/predeploy.go | 98 ++++++++++++++++++++++++++++++++++------- 1 file changed, 81 insertions(+), 17 deletions(-) diff --git a/pkg/deploy/predeploy.go b/pkg/deploy/predeploy.go index 4a3097a783b..42c504bab46 100644 --- a/pkg/deploy/predeploy.go +++ b/pkg/deploy/predeploy.go @@ -14,6 +14,7 @@ import ( "strings" "time" + mgmtcompute "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2020-06-01/compute" azkeyvault "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.0/keyvault" mgmtfeatures "github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2019-07-01/features" "github.com/Azure/go-autorest/autorest/azure" @@ -352,6 +353,7 @@ func (d *deployer) deployPreDeploy(ctx context.Context, resourceGroupName, deplo } func (d *deployer) configureServiceSecrets(ctx context.Context) error { + isRotated := false for _, s := range []struct { kv keyvault.Manager secretName string @@ -361,7 +363,8 @@ func (d *deployer) configureServiceSecrets(ctx context.Context) error { {d.serviceKeyvault, env.FrontendEncryptionSecretV2Name, 64}, {d.portalKeyvault, env.PortalServerSessionKeySecretName, 32}, } { - err := d.ensureAndRotateSecret(ctx, s.kv, s.secretName, s.len) + isNew, err := d.ensureAndRotateSecret(ctx, s.kv, s.secretName, s.len) + isRotated = isNew || isRotated if err != nil { return err } @@ -376,26 +379,43 @@ func (d *deployer) configureServiceSecrets(ctx context.Context) error { {d.serviceKeyvault, env.EncryptionSecretName, 32}, {d.serviceKeyvault, env.FrontendEncryptionSecretName, 32}, } { - err := d.ensureSecret(ctx, s.kv, s.secretName, s.len) + isNew, err := d.ensureSecret(ctx, s.kv, s.secretName, s.len) + isRotated = isNew || isRotated if err != nil { return err } } - return d.ensureSecretKey(ctx, d.portalKeyvault, env.PortalServerSSHKeySecretName) + isNew, err := d.ensureSecretKey(ctx, d.portalKeyvault, env.PortalServerSSHKeySecretName) + isRotated = isNew || isRotated + if err != nil { + return err + } + + if isRotated { + err = d.restartOldScalesets(ctx, "systemctl restart aro-gateway", d.config.GatewayResourceGroupName) + if err != nil { + return err + } + err = d.restartOldScalesets(ctx, "systemctl restart aro-rp", d.config.RPResourceGroupName) + if err != nil { + return err + } + } + return nil } -func (d *deployer) ensureAndRotateSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) error { +func (d *deployer) ensureAndRotateSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) (isNew bool, err error) { existingSecrets, err := kv.GetSecrets(ctx) if err != nil { - return err + return false, err } for _, secret := range existingSecrets { if filepath.Base(*secret.ID) == secretName { latestVersion, err := kv.GetSecret(ctx, secretName) if err != nil { - return err + return false, err } updatedTime := time.Unix(0, latestVersion.Attributes.Created.Duration().Nanoseconds()).Add(rotateSecretAfter) @@ -403,27 +423,27 @@ func (d *deployer) ensureAndRotateSecret(ctx context.Context, kv keyvault.Manage // do not create a secret if rotateSecretAfter time has // not elapsed since the secret version's creation timestamp if time.Now().Before(updatedTime) { - return nil + return false, nil } } } - return d.createSecret(ctx, kv, secretName, len) + return true, d.createSecret(ctx, kv, secretName, len) } -func (d *deployer) ensureSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) error { +func (d *deployer) ensureSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) (isNew bool, err error) { existingSecrets, err := kv.GetSecrets(ctx) if err != nil { - return err + return false, err } for _, secret := range existingSecrets { if filepath.Base(*secret.ID) == secretName { - return nil + return false, nil } } - return d.createSecret(ctx, kv, secretName, len) + return true, d.createSecret(ctx, kv, secretName, len) } func (d *deployer) createSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) error { @@ -439,25 +459,69 @@ func (d *deployer) createSecret(ctx context.Context, kv keyvault.Manager, secret }) } -func (d *deployer) ensureSecretKey(ctx context.Context, kv keyvault.Manager, secretName string) error { +func (d *deployer) ensureSecretKey(ctx context.Context, kv keyvault.Manager, secretName string) (isNew bool, err error) { existingSecrets, err := kv.GetSecrets(ctx) if err != nil { - return err + return false, err } for _, secret := range existingSecrets { if filepath.Base(*secret.ID) == secretName { - return nil + return false, nil } } key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - return err + return false, err } d.log.Infof("setting %s", secretName) - return kv.SetSecret(ctx, secretName, azkeyvault.SecretSetParameters{ + return true, kv.SetSecret(ctx, secretName, azkeyvault.SecretSetParameters{ Value: to.StringPtr(base64.StdEncoding.EncodeToString(x509.MarshalPKCS1PrivateKey(key))), }) } + +func (d *deployer) restartOldScalesets(ctx context.Context, script string, resourceGroupName string) error { + d.log.Print("restarting old scalesets") + scalesets, err := d.vmss.List(ctx, resourceGroupName) + if err != nil { + return err + } + + for _, vmss := range scalesets { + err = d.restartOldScaleset(ctx, *vmss.Name, script, resourceGroupName) + if err != nil { + return err + } + } + + return nil +} + +func (d *deployer) restartOldScaleset(ctx context.Context, vmssName string, script string, resourceGroupName string) error { + scalesetVMs, err := d.vmssvms.List(ctx, resourceGroupName, vmssName, "", "", "") + if err != nil { + return err + } + + d.log.Printf("restarting scaleset %s", vmssName) + errors := make(chan error, len(scalesetVMs)) + for _, vm := range scalesetVMs { + go func(id string) { + errors <- d.vmssvms.RunCommandAndWait(ctx, resourceGroupName, vmssName, id, mgmtcompute.RunCommandInput{ + CommandID: to.StringPtr("RunShellScript"), + Script: &[]string{script}, + }) + }(*vm.InstanceID) + } + + d.log.Print("waiting for instances to restart") + for range scalesetVMs { + err := <-errors + if err != nil { + return err + } + } + return nil +}