Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update predeploy to restart old VMSS when service secrets rotated #2946

Merged
148 changes: 128 additions & 20 deletions pkg/deploy/predeploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,34 @@ import (
"crypto/x509"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"path/filepath"
"strings"
"time"

mgmtcompute "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2020-06-01/compute"
azkeyvault "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.0/keyvault"
mgmtfeatures "github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2019-07-01/features"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/Azure/go-autorest/autorest/to"
"k8s.io/apimachinery/pkg/util/wait"

"github.com/Azure/ARO-RP/pkg/api"
"github.com/Azure/ARO-RP/pkg/deploy/assets"
"github.com/Azure/ARO-RP/pkg/deploy/generator"
"github.com/Azure/ARO-RP/pkg/env"
"github.com/Azure/ARO-RP/pkg/util/arm"
"github.com/Azure/ARO-RP/pkg/util/keyvault"
)

// Rotate the secret on every deploy of the RP if the most recent
// secret is greater than 7 days old
const rotateSecretAfter = time.Hour * 168
const (
// Rotate the secret on every deploy of the RP if the most recent
// secret is greater than 7 days old
rotateSecretAfter = time.Hour * 24 * 7
rpRestartScript = "systemctl restart aro-rp"
gatewayRestartScript = "systemctl restart aro-gateway"
)

// PreDeploy deploys managed identity, NSGs and keyvaults, needed for main
// deployment
Expand Down Expand Up @@ -352,6 +361,7 @@ func (d *deployer) deployPreDeploy(ctx context.Context, resourceGroupName, deplo
}

func (d *deployer) configureServiceSecrets(ctx context.Context) error {
isRotated := false
for _, s := range []struct {
kv keyvault.Manager
secretName string
Expand All @@ -361,7 +371,8 @@ func (d *deployer) configureServiceSecrets(ctx context.Context) error {
{d.serviceKeyvault, env.FrontendEncryptionSecretV2Name, 64},
{d.portalKeyvault, env.PortalServerSessionKeySecretName, 32},
} {
err := d.ensureAndRotateSecret(ctx, s.kv, s.secretName, s.len)
isNew, err := d.ensureAndRotateSecret(ctx, s.kv, s.secretName, s.len)
isRotated = isNew || isRotated
if err != nil {
return err
}
Expand All @@ -376,54 +387,71 @@ func (d *deployer) configureServiceSecrets(ctx context.Context) error {
{d.serviceKeyvault, env.EncryptionSecretName, 32},
{d.serviceKeyvault, env.FrontendEncryptionSecretName, 32},
} {
err := d.ensureSecret(ctx, s.kv, s.secretName, s.len)
isNew, err := d.ensureSecret(ctx, s.kv, s.secretName, s.len)
isRotated = isNew || isRotated
if err != nil {
return err
}
}

return d.ensureSecretKey(ctx, d.portalKeyvault, env.PortalServerSSHKeySecretName)
isNew, err := d.ensureSecretKey(ctx, d.portalKeyvault, env.PortalServerSSHKeySecretName)
isRotated = isNew || isRotated
if err != nil {
return err
}

if isRotated {
err = d.restartOldScalesets(ctx, d.config.GatewayResourceGroupName)
if err != nil {
return err
}
err = d.restartOldScalesets(ctx, d.config.RPResourceGroupName)
if err != nil {
return err
}
}
return nil
}

func (d *deployer) ensureAndRotateSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) error {
func (d *deployer) ensureAndRotateSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) (isNew bool, err error) {
existingSecrets, err := kv.GetSecrets(ctx)
if err != nil {
return err
return false, err
}

for _, secret := range existingSecrets {
if filepath.Base(*secret.ID) == secretName {
latestVersion, err := kv.GetSecret(ctx, secretName)
if err != nil {
return err
return false, err
}

updatedTime := time.Unix(0, latestVersion.Attributes.Created.Duration().Nanoseconds()).Add(rotateSecretAfter)

// do not create a secret if rotateSecretAfter time has
// not elapsed since the secret version's creation timestamp
if time.Now().Before(updatedTime) {
return nil
return false, nil
}
}
}

return d.createSecret(ctx, kv, secretName, len)
return true, d.createSecret(ctx, kv, secretName, len)
}

func (d *deployer) ensureSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) error {
func (d *deployer) ensureSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) (isNew bool, err error) {
existingSecrets, err := kv.GetSecrets(ctx)
if err != nil {
return err
return false, err
}

for _, secret := range existingSecrets {
if filepath.Base(*secret.ID) == secretName {
return nil
return false, nil
}
}

return d.createSecret(ctx, kv, secretName, len)
return true, d.createSecret(ctx, kv, secretName, len)
}

func (d *deployer) createSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) error {
Expand All @@ -439,25 +467,105 @@ func (d *deployer) createSecret(ctx context.Context, kv keyvault.Manager, secret
})
}

func (d *deployer) ensureSecretKey(ctx context.Context, kv keyvault.Manager, secretName string) error {
func (d *deployer) ensureSecretKey(ctx context.Context, kv keyvault.Manager, secretName string) (isNew bool, err error) {
existingSecrets, err := kv.GetSecrets(ctx)
if err != nil {
return err
return false, err
}

for _, secret := range existingSecrets {
if filepath.Base(*secret.ID) == secretName {
return nil
return false, nil
}
}

key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return err
return false, err
}

d.log.Infof("setting %s", secretName)
return kv.SetSecret(ctx, secretName, azkeyvault.SecretSetParameters{
return true, kv.SetSecret(ctx, secretName, azkeyvault.SecretSetParameters{
Value: to.StringPtr(base64.StdEncoding.EncodeToString(x509.MarshalPKCS1PrivateKey(key))),
})
}

func (d *deployer) restartOldScalesets(ctx context.Context, resourceGroupName string) error {
scalesets, err := d.vmss.List(ctx, resourceGroupName)
if err != nil {
return err
}

for _, vmss := range scalesets {
err = d.restartOldScaleset(ctx, *vmss.Name, resourceGroupName)
if err != nil {
return err
}
}

return nil
}

func (d *deployer) restartOldScaleset(ctx context.Context, vmssName string, resourceGroupName string) error {
var restartScript string
switch {
case strings.HasPrefix(vmssName, gatewayVMSSPrefix):
restartScript = gatewayRestartScript
case strings.HasPrefix(vmssName, rpVMSSPrefix):
restartScript = rpRestartScript
default:
return &api.CloudError{
StatusCode: http.StatusBadRequest,
CloudErrorBody: &api.CloudErrorBody{
Code: api.CloudErrorCodeInvalidResource,
Message: fmt.Sprintf("provided vmss %s does not match RP or gateway prefix",
vmssName,
),
},
}
}

scalesetVMs, err := d.vmssvms.List(ctx, resourceGroupName, vmssName, "", "", "")
if err != nil {
return err
}

for _, vm := range scalesetVMs {
d.log.Printf("waiting for restart script to complete on older vmss %s, instance %s", vmssName, *vm.InstanceID)
err = d.vmssvms.RunCommandAndWait(ctx, resourceGroupName, vmssName, *vm.InstanceID, mgmtcompute.RunCommandInput{
CommandID: to.StringPtr("RunShellScript"),
Script: &[]string{restartScript},
})

if err != nil {
return err
}

// wait for load balancer probe to change the vm health status
time.Sleep(30 * time.Second)
timeoutCtx, cancel := context.WithTimeout(ctx, time.Hour)
defer cancel()
err = d.waitForReadiness(timeoutCtx, resourceGroupName, vmssName, *vm.InstanceID)
if err != nil {
return err
}
}

return nil
}

func (d *deployer) waitForReadiness(ctx context.Context, resourceGroupName string, vmssName string, vmInstanceID string) error {
return wait.PollImmediateUntil(10*time.Second, func() (bool, error) {
return d.isVMInstanceHealthy(ctx, resourceGroupName, vmssName, vmInstanceID), nil
}, ctx.Done())
}

func (d *deployer) isVMInstanceHealthy(ctx context.Context, resourceGroupName string, vmssName string, vmInstanceID string) bool {
r, err := d.vmssvms.GetInstanceView(ctx, resourceGroupName, vmssName, vmInstanceID)
instanceUnhealthy := r.VMHealth != nil && r.VMHealth.Status != nil && r.VMHealth.Status.Code != nil && *r.VMHealth.Status.Code != "HealthState/healthy"
if err != nil || instanceUnhealthy {
d.log.Printf("instance %s is unhealthy", vmInstanceID)
return false
}
return true
}
Loading