Skip to content

Commit

Permalink
update predeploy to restart old VMSS when service secrets rotated (#2946
Browse files Browse the repository at this point in the history
)

* update predeploy to restart old VMSS when service secrets rotated
* update scalesetVMSS conditions check for restart at RP predeploy
* add vmss health check after vmss restart
* nit changes related to logging
* remove concurrent rp service restarts
* Add unit test cases for RP predeploy function
* refactor predeploy.go unit test cases
* remove variables duplication from predeploy test cases
  • Loading branch information
rajdeepc2792 authored Jul 28, 2023
1 parent 06cfba5 commit f6129d9
Show file tree
Hide file tree
Showing 6 changed files with 1,929 additions and 28 deletions.
148 changes: 128 additions & 20 deletions pkg/deploy/predeploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,34 @@ import (
"crypto/x509"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"path/filepath"
"strings"
"time"

mgmtcompute "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2020-06-01/compute"
azkeyvault "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.0/keyvault"
mgmtfeatures "github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2019-07-01/features"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/Azure/go-autorest/autorest/to"
"k8s.io/apimachinery/pkg/util/wait"

"github.com/Azure/ARO-RP/pkg/api"
"github.com/Azure/ARO-RP/pkg/deploy/assets"
"github.com/Azure/ARO-RP/pkg/deploy/generator"
"github.com/Azure/ARO-RP/pkg/env"
"github.com/Azure/ARO-RP/pkg/util/arm"
"github.com/Azure/ARO-RP/pkg/util/keyvault"
)

// Rotate the secret on every deploy of the RP if the most recent
// secret is greater than 7 days old
const rotateSecretAfter = time.Hour * 168
const (
// Rotate the secret on every deploy of the RP if the most recent
// secret is greater than 7 days old
rotateSecretAfter = time.Hour * 24 * 7
rpRestartScript = "systemctl restart aro-rp"
gatewayRestartScript = "systemctl restart aro-gateway"
)

// PreDeploy deploys managed identity, NSGs and keyvaults, needed for main
// deployment
Expand Down Expand Up @@ -352,6 +361,7 @@ func (d *deployer) deployPreDeploy(ctx context.Context, resourceGroupName, deplo
}

// configureServiceSecrets ensures the service and portal keyvault secrets
// listed below exist (rotating those handled by ensureAndRotateSecret) and
// ensures the portal SSH key secret exists. isRotated accumulates whether any
// of the ensure* helpers reported a newly written secret; when it is true the
// scale sets in the gateway and RP resource groups are restarted via
// restartOldScalesets. The first error encountered is returned as-is.
func (d *deployer) configureServiceSecrets(ctx context.Context) error {
isRotated := false
for _, s := range []struct {
kv keyvault.Manager
secretName string
Expand All @@ -361,7 +371,8 @@ func (d *deployer) configureServiceSecrets(ctx context.Context) error {
{d.serviceKeyvault, env.FrontendEncryptionSecretV2Name, 64},
{d.portalKeyvault, env.PortalServerSessionKeySecretName, 32},
} {
err := d.ensureAndRotateSecret(ctx, s.kv, s.secretName, s.len)
isNew, err := d.ensureAndRotateSecret(ctx, s.kv, s.secretName, s.len)
isRotated = isNew || isRotated
if err != nil {
return err
}
Expand All @@ -376,54 +387,71 @@ func (d *deployer) configureServiceSecrets(ctx context.Context) error {
{d.serviceKeyvault, env.EncryptionSecretName, 32},
{d.serviceKeyvault, env.FrontendEncryptionSecretName, 32},
} {
err := d.ensureSecret(ctx, s.kv, s.secretName, s.len)
isNew, err := d.ensureSecret(ctx, s.kv, s.secretName, s.len)
isRotated = isNew || isRotated
if err != nil {
return err
}
}

return d.ensureSecretKey(ctx, d.portalKeyvault, env.PortalServerSSHKeySecretName)
isNew, err := d.ensureSecretKey(ctx, d.portalKeyvault, env.PortalServerSSHKeySecretName)
isRotated = isNew || isRotated
if err != nil {
return err
}

// at least one secret was written during this run: restart the scale sets
// in the gateway resource group first, then those in the RP resource group
if isRotated {
err = d.restartOldScalesets(ctx, d.config.GatewayResourceGroupName)
if err != nil {
return err
}
err = d.restartOldScalesets(ctx, d.config.RPResourceGroupName)
if err != nil {
return err
}
}
return nil
}

// ensureAndRotateSecret creates secretName in kv via createSecret unless a
// secret with that name already exists and its latest version was created
// less than rotateSecretAfter ago. isNew is false when the existing version
// is kept or when listing/fetching fails; it is true whenever createSecret is
// called, in which case err carries createSecret's result.
func (d *deployer) ensureAndRotateSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) error {
func (d *deployer) ensureAndRotateSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) (isNew bool, err error) {
existingSecrets, err := kv.GetSecrets(ctx)
if err != nil {
return err
return false, err
}

for _, secret := range existingSecrets {
// the last path element of the secret ID is compared against secretName
if filepath.Base(*secret.ID) == secretName {
latestVersion, err := kv.GetSecret(ctx, secretName)
if err != nil {
return err
return false, err
}

// earliest time at which the latest version becomes due for rotation
updatedTime := time.Unix(0, latestVersion.Attributes.Created.Duration().Nanoseconds()).Add(rotateSecretAfter)

// do not create a secret if rotateSecretAfter time has
// not elapsed since the secret version's creation timestamp
if time.Now().Before(updatedTime) {
return nil
return false, nil
}
}
}

return d.createSecret(ctx, kv, secretName, len)
return true, d.createSecret(ctx, kv, secretName, len)
}

// ensureSecret creates secretName in kv via createSecret only when no secret
// whose ID ends in secretName is returned by kv.GetSecrets. isNew is false
// when the secret already exists or listing fails, and true when createSecret
// is called (err then carries createSecret's result).
func (d *deployer) ensureSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) error {
func (d *deployer) ensureSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) (isNew bool, err error) {
existingSecrets, err := kv.GetSecrets(ctx)
if err != nil {
return err
return false, err
}

for _, secret := range existingSecrets {
// the last path element of the secret ID is compared against secretName
if filepath.Base(*secret.ID) == secretName {
return nil
return false, nil
}
}

return d.createSecret(ctx, kv, secretName, len)
return true, d.createSecret(ctx, kv, secretName, len)
}

func (d *deployer) createSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) error {
Expand All @@ -439,25 +467,105 @@ func (d *deployer) createSecret(ctx context.Context, kv keyvault.Manager, secret
})
}

// ensureSecretKey stores a freshly generated 2048-bit RSA private key under
// secretName in kv when no secret whose ID ends in secretName already exists.
// The key is serialized with x509.MarshalPKCS1PrivateKey and base64 (standard
// encoding) before being set. isNew is true only when kv.SetSecret is called
// (err then carries SetSecret's result).
func (d *deployer) ensureSecretKey(ctx context.Context, kv keyvault.Manager, secretName string) error {
func (d *deployer) ensureSecretKey(ctx context.Context, kv keyvault.Manager, secretName string) (isNew bool, err error) {
existingSecrets, err := kv.GetSecrets(ctx)
if err != nil {
return err
return false, err
}

for _, secret := range existingSecrets {
// the last path element of the secret ID is compared against secretName
if filepath.Base(*secret.ID) == secretName {
return nil
return false, nil
}
}

key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return err
return false, err
}

d.log.Infof("setting %s", secretName)
return kv.SetSecret(ctx, secretName, azkeyvault.SecretSetParameters{
return true, kv.SetSecret(ctx, secretName, azkeyvault.SecretSetParameters{
Value: to.StringPtr(base64.StdEncoding.EncodeToString(x509.MarshalPKCS1PrivateKey(key))),
})
}

// restartOldScalesets lists every scale set in resourceGroupName and hands
// each one, in listing order, to restartOldScaleset. It stops at and returns
// the first error, whether from the listing or from a restart.
func (d *deployer) restartOldScalesets(ctx context.Context, resourceGroupName string) error {
	found, listErr := d.vmss.List(ctx, resourceGroupName)
	if listErr != nil {
		return listErr
	}

	for _, scaleset := range found {
		if restartErr := d.restartOldScaleset(ctx, *scaleset.Name, resourceGroupName); restartErr != nil {
			return restartErr
		}
	}

	return nil
}

// restartOldScaleset restarts the ARO service on every VM instance of the
// scale set vmssName in resourceGroupName, one instance at a time.
//
// The restart command is chosen from the scale set name: names starting with
// gatewayVMSSPrefix get gatewayRestartScript, names starting with
// rpVMSSPrefix get rpRestartScript. Any other name yields an
// *api.CloudError with HTTP 400 / CloudErrorCodeInvalidResource.
//
// The first error from listing the instances, running the command, or waiting
// for an instance to report healthy aborts the loop and is returned.
func (d *deployer) restartOldScaleset(ctx context.Context, vmssName string, resourceGroupName string) error {
	var restartScript string
	switch {
	case strings.HasPrefix(vmssName, gatewayVMSSPrefix):
		restartScript = gatewayRestartScript
	case strings.HasPrefix(vmssName, rpVMSSPrefix):
		restartScript = rpRestartScript
	default:
		return &api.CloudError{
			StatusCode: http.StatusBadRequest,
			CloudErrorBody: &api.CloudErrorBody{
				Code: api.CloudErrorCodeInvalidResource,
				Message: fmt.Sprintf("provided vmss %s does not match RP or gateway prefix",
					vmssName,
				),
			},
		}
	}

	scalesetVMs, err := d.vmssvms.List(ctx, resourceGroupName, vmssName, "", "", "")
	if err != nil {
		return err
	}

	for _, vm := range scalesetVMs {
		// Per-instance work lives in its own function so that the timeout
		// context it creates is released as soon as that instance is done,
		// instead of piling up deferred cancels until the whole loop ends.
		err = d.restartScalesetInstanceAndWait(ctx, resourceGroupName, vmssName, *vm.InstanceID, restartScript)
		if err != nil {
			return err
		}
	}

	return nil
}

// restartScalesetInstanceAndWait runs restartScript on a single scale set
// instance, pauses 30 seconds, and then polls (for up to one hour) until the
// instance is reported healthy. It returns ctx.Err() if ctx is cancelled
// during the pause.
func (d *deployer) restartScalesetInstanceAndWait(ctx context.Context, resourceGroupName, vmssName, instanceID, restartScript string) error {
	d.log.Printf("waiting for restart script to complete on older vmss %s, instance %s", vmssName, instanceID)
	err := d.vmssvms.RunCommandAndWait(ctx, resourceGroupName, vmssName, instanceID, mgmtcompute.RunCommandInput{
		CommandID: to.StringPtr("RunShellScript"),
		Script:    &[]string{restartScript},
	})
	if err != nil {
		return err
	}

	// wait for load balancer probe to change the vm health status, but give
	// up immediately if the caller's context is cancelled in the meantime
	settle := time.NewTimer(30 * time.Second)
	defer settle.Stop()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-settle.C:
	}

	timeoutCtx, cancel := context.WithTimeout(ctx, time.Hour)
	defer cancel()
	return d.waitForReadiness(timeoutCtx, resourceGroupName, vmssName, instanceID)
}

// waitForReadiness checks the given scale set instance immediately and then
// every 10 seconds until isVMInstanceHealthy reports true or ctx is done.
// The health probe itself never returns an error to the poller; an error
// from this function therefore comes from wait.PollImmediateUntil.
func (d *deployer) waitForReadiness(ctx context.Context, resourceGroupName string, vmssName string, vmInstanceID string) error {
	const pollInterval = 10 * time.Second

	instanceHealthy := func() (bool, error) {
		healthy := d.isVMInstanceHealthy(ctx, resourceGroupName, vmssName, vmInstanceID)
		return healthy, nil
	}

	return wait.PollImmediateUntil(pollInterval, instanceHealthy, ctx.Done())
}

// isVMInstanceHealthy fetches the instance view of the given scale set VM and
// reports whether it should be considered healthy.
//
// It returns false when the instance view cannot be retrieved, or when the
// view carries a health status code other than "HealthState/healthy". An
// instance view without any health status code is treated as healthy.
func (d *deployer) isVMInstanceHealthy(ctx context.Context, resourceGroupName string, vmssName string, vmInstanceID string) bool {
	r, err := d.vmssvms.GetInstanceView(ctx, resourceGroupName, vmssName, vmInstanceID)
	if err != nil {
		// surface the underlying failure instead of silently dropping it, so
		// an API error can be told apart from a genuinely unhealthy instance
		d.log.Printf("instance %s is unhealthy: getting instance view: %v", vmInstanceID, err)
		return false
	}

	if r.VMHealth != nil && r.VMHealth.Status != nil && r.VMHealth.Status.Code != nil && *r.VMHealth.Status.Code != "HealthState/healthy" {
		d.log.Printf("instance %s is unhealthy", vmInstanceID)
		return false
	}

	return true
}
Loading

0 comments on commit f6129d9

Please sign in to comment.