82 changes: 73 additions & 9 deletions e2e/scenario_gpu_managed_experience_test.go
@@ -3,6 +3,7 @@ package e2e
import (
"context"
"testing"
"time"

"github.com/Azure/agentbaker/e2e/components"
"github.com/Azure/agentbaker/e2e/config"
@@ -63,10 +64,10 @@ func Test_Ubuntu2404_NvidiaDevicePluginRunning(t *testing.T) {
ValidateNvidiaDevicePluginServiceRunning(ctx, s)

// Validate that GPU resources are advertised by the device plugin
ValidateNodeAdvertisesGPUResources(ctx, s)
ValidateNodeAdvertisesGPUResources(ctx, s, 1)

// Validate that GPU workloads can be scheduled
ValidateGPUWorkloadSchedulable(ctx, s)
ValidateGPUWorkloadSchedulable(ctx, s, 1)

// Validate that the NVIDIA DCGM packages were installed correctly
for _, packageName := range getDCGMPackageNames(os) {
@@ -77,7 +78,7 @@ func Test_Ubuntu2404_NvidiaDevicePluginRunning(t *testing.T) {

ValidateNvidiaDCGMExporterSystemDServiceRunning(ctx, s)
ValidateNvidiaDCGMExporterIsScrapable(ctx, s)
ValidateNvidiaDCGMExporterScrapeCommonMetric(ctx, s)
ValidateNvidiaDCGMExporterScrapeCommonMetric(ctx, s, "DCGM_FI_DEV_GPU_UTIL")
},
},
})
Expand Down Expand Up @@ -118,10 +119,10 @@ func Test_Ubuntu2204_NvidiaDevicePluginRunning(t *testing.T) {
ValidateNvidiaDevicePluginServiceRunning(ctx, s)

// Validate that GPU resources are advertised by the device plugin
ValidateNodeAdvertisesGPUResources(ctx, s)
ValidateNodeAdvertisesGPUResources(ctx, s, 1)

// Validate that GPU workloads can be scheduled
ValidateGPUWorkloadSchedulable(ctx, s)
ValidateGPUWorkloadSchedulable(ctx, s, 1)

for _, packageName := range getDCGMPackageNames(os) {
versions := components.GetExpectedPackageVersions(packageName, os, osVersion)
@@ -131,7 +132,7 @@ func Test_Ubuntu2204_NvidiaDevicePluginRunning(t *testing.T) {

ValidateNvidiaDCGMExporterSystemDServiceRunning(ctx, s)
ValidateNvidiaDCGMExporterIsScrapable(ctx, s)
ValidateNvidiaDCGMExporterScrapeCommonMetric(ctx, s)
ValidateNvidiaDCGMExporterScrapeCommonMetric(ctx, s, "DCGM_FI_DEV_GPU_UTIL")
},
},
})
Expand Down Expand Up @@ -172,10 +173,10 @@ func Test_AzureLinux3_NvidiaDevicePluginRunning(t *testing.T) {
ValidateNvidiaDevicePluginServiceRunning(ctx, s)

// Validate that GPU resources are advertised by the device plugin
ValidateNodeAdvertisesGPUResources(ctx, s)
ValidateNodeAdvertisesGPUResources(ctx, s, 1)

// Validate that GPU workloads can be scheduled
ValidateGPUWorkloadSchedulable(ctx, s)
ValidateGPUWorkloadSchedulable(ctx, s, 1)

for _, packageName := range getDCGMPackageNames(os) {
versions := components.GetExpectedPackageVersions(packageName, os, osVersion)
@@ -185,7 +186,70 @@ func Test_AzureLinux3_NvidiaDevicePluginRunning(t *testing.T) {

ValidateNvidiaDCGMExporterSystemDServiceRunning(ctx, s)
ValidateNvidiaDCGMExporterIsScrapable(ctx, s)
ValidateNvidiaDCGMExporterScrapeCommonMetric(ctx, s)
ValidateNvidiaDCGMExporterScrapeCommonMetric(ctx, s, "DCGM_FI_DEV_GPU_UTIL")
},
},
})
}

func Test_Ubuntu2404_NvidiaDevicePluginRunning_MIG(t *testing.T) {
RunScenario(t, &Scenario{
Description: "Tests that NVIDIA device plugin and DCGM Exporter work with MIG enabled on Ubuntu 24.04 GPU nodes",
Tags: Tags{
GPU: true,
},
Config: Config{
Cluster: ClusterKubenet,
VHD: config.VHDUbuntu2404Gen2Containerd,
WaitForSSHAfterReboot: 5 * time.Minute,
BootstrapConfigMutator: func(nbc *datamodel.NodeBootstrappingConfiguration) {
nbc.AgentPoolProfile.VMSize = "Standard_NC24ads_A100_v4"
nbc.ConfigGPUDriverIfNeeded = true
nbc.EnableGPUDevicePluginIfNeeded = true
nbc.EnableNvidia = true
nbc.GPUInstanceProfile = "MIG2g"
},
VMConfigMutator: func(vmss *armcompute.VirtualMachineScaleSet) {
vmss.SKU.Name = to.Ptr("Standard_NC24ads_A100_v4")
if vmss.Tags == nil {
vmss.Tags = map[string]*string{}
}
vmss.Tags["EnableManagedGPUExperience"] = to.Ptr("true")
},
Validator: func(ctx context.Context, s *Scenario) {
os := "ubuntu"
osVersion := "r2404"

// Validate that the NVIDIA device plugin binary was installed correctly
versions := components.GetExpectedPackageVersions("nvidia-device-plugin", os, osVersion)
require.Lenf(s.T, versions, 1, "Expected exactly one nvidia-device-plugin version for %s %s but got %d", os, osVersion, len(versions))
ValidateInstalledPackageVersion(ctx, s, "nvidia-device-plugin", versions[0])

// Validate that the NVIDIA device plugin systemd service is running
ValidateNvidiaDevicePluginServiceRunning(ctx, s)

// Validate that MIG mode is enabled via nvidia-smi
ValidateMIGModeEnabled(ctx, s)

// Validate that MIG instances are created
ValidateMIGInstancesCreated(ctx, s, "MIG 2g.20gb")

// Validate that GPU resources are advertised by the device plugin
ValidateNodeAdvertisesGPUResources(ctx, s, 3)

// Validate that MIG workloads can be scheduled
ValidateGPUWorkloadSchedulable(ctx, s, 3)

// Validate that the NVIDIA DCGM packages were installed correctly
for _, packageName := range getDCGMPackageNames(os) {
versions := components.GetExpectedPackageVersions(packageName, os, osVersion)
require.Lenf(s.T, versions, 1, "Expected exactly one %s version for %s %s but got %d", packageName, os, osVersion, len(versions))
ValidateInstalledPackageVersion(ctx, s, packageName, versions[0])
}

ValidateNvidiaDCGMExporterSystemDServiceRunning(ctx, s)
ValidateNvidiaDCGMExporterIsScrapable(ctx, s)
ValidateNvidiaDCGMExporterScrapeCommonMetric(ctx, s, "DCGM_FI_DEV_GPU_TEMP")
},
},
})
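For context on the expected counts in this scenario: Standard_NC24ads_A100_v4 carries a single A100 80GB, and the MIG2g profile splits it into 2g.20gb instances; with the device plugin's single MIG strategy each instance is advertised as one generic nvidia.com/gpu resource, hence the expected count of 3. A hypothetical lookup (not part of this PR; counts reflect my understanding of A100 80GB MIG geometry) makes the arithmetic explicit:

// Hypothetical helper, not in the PR: GPU instances exposed by one A100 80GB
// per AKS GPUInstanceProfile value. With the "single" MIG strategy, this is the
// nvidia.com/gpu count the parameterized validators above should expect.
var migInstancesPerA100_80GB = map[string]int64{
    "MIG1g": 7, // 1g.10gb
    "MIG2g": 3, // 2g.20gb -> ValidateNodeAdvertisesGPUResources(ctx, s, 3)
    "MIG3g": 2, // 3g.40gb
    "MIG7g": 1, // 7g.80gb
}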
75 changes: 75 additions & 0 deletions e2e/test_helpers.go
@@ -22,6 +22,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/wait"
ctrruntimelog "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/controller-runtime/pkg/log/zap"
)
@@ -649,7 +650,81 @@ func CreateSIGImageVersionFromDisk(ctx context.Context, s *Scenario, version str
return &customVHD
}

// isRebootRelatedSSHError checks if the error is related to a system reboot
func isRebootRelatedSSHError(err error, stderr string) bool {
if err == nil {
return false
}

rebootIndicators := []string{
"System is going down",
"pam_nologin",
"Connection closed by",
"Connection refused",
"Connection timed out",
}

errMsg := err.Error()
for _, indicator := range rebootIndicators {
if strings.Contains(errMsg, indicator) || strings.Contains(stderr, indicator) {
return true
}
}
return false
}

func validateSSHConnectivity(ctx context.Context, s *Scenario) error {
// If WaitForSSHAfterReboot is not set, use the original single-attempt behavior
if s.Config.WaitForSSHAfterReboot == 0 {
return attemptSSHConnection(ctx, s)
}

// Retry logic with a fixed 5-second polling interval for scenarios that may reboot
s.T.Logf("SSH connectivity validation will retry for up to %s if reboot-related errors are encountered", s.Config.WaitForSSHAfterReboot)
startTime := time.Now()
var lastSSHError error

err := wait.PollUntilContextTimeout(ctx, 5*time.Second, s.Config.WaitForSSHAfterReboot, true, func(ctx context.Context) (bool, error) {
err := attemptSSHConnection(ctx, s)
if err == nil {
elapsed := time.Since(startTime)
s.T.Logf("SSH connectivity established after %s", toolkit.FormatDuration(elapsed))
return true, nil
}

// Save the last error for better error messages
lastSSHError = err

// Extract stderr from the error
stderr := ""
if strings.Contains(err.Error(), "Stderr:") {
parts := strings.Split(err.Error(), "Stderr:")
if len(parts) > 1 {
stderr = parts[1]
}
}

// Check if this is a reboot-related error
if isRebootRelatedSSHError(err, stderr) {
s.T.Logf("Detected reboot-related SSH error, will retry: %v", err)
return false, nil // Continue polling
}

// Not a reboot error, fail immediately
return false, err
})

// If we timed out while retrying reboot-related errors, provide a better error message
if err != nil && lastSSHError != nil {
elapsed := time.Since(startTime)
return fmt.Errorf("SSH connection failed after waiting %s for node to reboot and come back up. Last SSH error: %w", toolkit.FormatDuration(elapsed), lastSSHError)
}

return err
}

// attemptSSHConnection performs a single SSH connectivity check
func attemptSSHConnection(ctx context.Context, s *Scenario) error {
connectionTest := fmt.Sprintf("%s echo 'SSH_CONNECTION_OK'", sshString(s.Runtime.VMPrivateIP))
connectionResult, err := execOnPrivilegedPod(ctx, s.Runtime.Cluster.Kube, defaultNamespace, s.Runtime.Cluster.DebugPod.Name, connectionTest)

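As a quick illustration of which failures the new retry path treats as reboot-related, here is a hypothetical table-driven check (not part of the PR; it assumes it lives in the e2e package alongside test_helpers.go and uses testify's require):

package e2e

import (
    "errors"
    "testing"

    "github.com/stretchr/testify/require"
)

// Sketch only: exercises isRebootRelatedSSHError with representative inputs.
func Test_isRebootRelatedSSHError_Sketch(t *testing.T) {
    cases := []struct {
        name   string
        err    error
        stderr string
        want   bool
    }{
        {"nil error is never reboot-related", nil, "System is going down", false},
        {"indicator appears in stderr", errors.New("exit status 255"), "Connection closed by 10.0.0.4 port 22", true},
        {"indicator appears in the error itself", errors.New("dial tcp: Connection refused"), "", true},
        {"unrelated failure is not retried", errors.New("permission denied"), "", false},
    }
    for _, tc := range cases {
        require.Equal(t, tc.want, isRebootRelatedSSHError(tc.err, tc.stderr), tc.name)
    }
}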
6 changes: 6 additions & 0 deletions e2e/types.go
@@ -10,6 +10,7 @@ import (
"strconv"
"strings"
"testing"
"time"

aksnodeconfigv1 "github.com/Azure/agentbaker/aks-node-controller/pkg/gen/aksnodeconfig/v1"
"github.com/Azure/agentbaker/e2e/config"
@@ -166,6 +167,11 @@ type Config struct {
// It shouldn't be used for the majority of scenarios; currently it is only used for scenarios where the node is not expected to be reachable via ssh
SkipSSHConnectivityValidation bool

// WaitForSSHAfterReboot, if set to a non-zero duration, makes SSH connectivity validation retry (polling every few seconds)
// for up to this duration when it encounters reboot-related errors. This is useful for scenarios where the node
// reboots during provisioning (e.g., MIG-enabled GPU nodes). Default (zero value) means no retry.
WaitForSSHAfterReboot time.Duration

// if VHDCaching is set then a VHD will be created first for the test scenario and then a VM will be created from that VHD.
// The main purpose is to validate VHD Caching logic and ensure a reboot step between basePrep and nodePrep doesn't break anything.
VHDCaching bool
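For reference, opting a scenario into the reboot-tolerant SSH check is just a matter of setting this field; a minimal sketch mirroring the MIG scenario earlier in this diff (values illustrative):

Config: Config{
    Cluster: ClusterKubenet,
    VHD:     config.VHDUbuntu2404Gen2Containerd,
    // The node reboots while MIG is being enabled, so tolerate reboot-related
    // SSH failures for up to five minutes before failing the scenario.
    WaitForSSHAfterReboot: 5 * time.Minute,
    // ... BootstrapConfigMutator, VMConfigMutator, Validator as usual ...
},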
60 changes: 48 additions & 12 deletions e2e/validators.go
@@ -968,29 +968,29 @@ func ValidateNvidiaDevicePluginServiceRunning(ctx context.Context, s *Scenario)
execScriptOnVMForScenarioValidateExitCode(ctx, s, strings.Join(command, "\n"), 0, "NVIDIA device plugin systemd service should be active and enabled")
}

func ValidateNodeAdvertisesGPUResources(ctx context.Context, s *Scenario) {
func ValidateNodeAdvertisesGPUResources(ctx context.Context, s *Scenario, gpuCountExpected int64) {
s.T.Helper()
s.T.Logf("validating that node advertises GPU resources")
resourceName := "nvidia.com/gpu"

// First, wait for the nvidia.com/gpu resource to be available
waitUntilResourceAvailable(ctx, s, "nvidia.com/gpu")
waitUntilResourceAvailable(ctx, s, resourceName)

// Get the node using the Kubernetes client from the test framework
nodeName := s.Runtime.KubeNodeName
node, err := s.Runtime.Cluster.Kube.Typed.CoreV1().Nodes().Get(ctx, nodeName, metav1.GetOptions{})
require.NoError(s.T, err, "failed to get node %q", nodeName)

// Check if the node advertises GPU capacity
gpuCapacity, exists := node.Status.Capacity["nvidia.com/gpu"]
require.True(s.T, exists, "node should advertise nvidia.com/gpu capacity")
gpuCapacity, exists := node.Status.Capacity[corev1.ResourceName(resourceName)]
require.True(s.T, exists, "node should advertise resource %s", resourceName)

gpuCount := gpuCapacity.Value()
require.Greater(s.T, gpuCount, int64(0), "node should advertise at least 1 GPU, but got %d", gpuCount)

s.T.Logf("node %s advertises %d nvidia.com/gpu resources", nodeName, gpuCount)
require.Equal(s.T, gpuCountExpected, gpuCount, "node should advertise %s=%d, but got %s=%d", resourceName, gpuCountExpected, resourceName, gpuCount)
s.T.Logf("node %s advertises %s=%d resources", nodeName, resourceName, gpuCount)
}

func ValidateGPUWorkloadSchedulable(ctx context.Context, s *Scenario) {
func ValidateGPUWorkloadSchedulable(ctx context.Context, s *Scenario, gpuCount int) {
s.T.Helper()
s.T.Logf("validating that GPU workloads can be scheduled")

@@ -1014,7 +1014,7 @@ func ValidateGPUWorkloadSchedulable(ctx context.Context, s *Scenario) {
},
Resources: corev1.ResourceRequirements{
Limits: corev1.ResourceList{
"nvidia.com/gpu": resource.MustParse("1"),
"nvidia.com/gpu": resource.MustParse(fmt.Sprintf("%d", gpuCount)),
},
},
},
Expand Down Expand Up @@ -1156,12 +1156,48 @@ func ValidateNvidiaDCGMExporterIsScrapable(ctx context.Context, s *Scenario) {
execScriptOnVMForScenarioValidateExitCode(ctx, s, strings.Join(command, "\n"), 0, "Nvidia DCGM Exporter is not scrapable on port 19400")
}

func ValidateNvidiaDCGMExporterScrapeCommonMetric(ctx context.Context, s *Scenario) {
func ValidateNvidiaDCGMExporterScrapeCommonMetric(ctx context.Context, s *Scenario, metric string) {
s.T.Helper()
command := []string{
"set -ex",
// Verify the most universal GPU metric is present
"curl -s http://localhost:19400/metrics | grep -q 'DCGM_FI_DEV_GPU_UTIL'",
"curl -s http://localhost:19400/metrics | grep -q '" + metric + "'",
}
execScriptOnVMForScenarioValidateExitCode(ctx, s, strings.Join(command, "\n"), 0, "Nvidia DCGM Exporter is not returning DCGM_FI_DEV_GPU_UTIL")
execScriptOnVMForScenarioValidateExitCode(ctx, s, strings.Join(command, "\n"), 0, "Nvidia DCGM Exporter is not returning "+metric)
}

func ValidateMIGModeEnabled(ctx context.Context, s *Scenario) {
s.T.Helper()
s.T.Logf("validating that MIG mode is enabled")

command := []string{
"set -ex",
// Grep to verify it contains 'Enabled' - this will fail if MIG is disabled
"sudo nvidia-smi --query-gpu=mig.mode.current --format=csv,noheader | grep -i 'Enabled'",
}
execResult := execScriptOnVMForScenarioValidateExitCode(ctx, s, strings.Join(command, "\n"), 0, "MIG mode is not enabled")

stdout := strings.TrimSpace(execResult.stdout.String())
s.T.Logf("MIG mode status: %s", stdout)
require.Contains(s.T, stdout, "Enabled", "expected MIG mode to be enabled, but got: %s", stdout)
s.T.Logf("MIG mode is enabled")
}

func ValidateMIGInstancesCreated(ctx context.Context, s *Scenario, migProfile string) {
s.T.Helper()
s.T.Logf("validating that MIG instances are created with profile %s", migProfile)

command := []string{
"set -ex",
// List MIG devices using nvidia-smi
"sudo nvidia-smi mig -lgi",
// Ensure the output contains the expected MIG profile (will fail if "No MIG-enabled devices found")
"sudo nvidia-smi mig -lgi | grep -v 'No MIG-enabled devices found' | grep -q '" + migProfile + "'",
}
execResult := execScriptOnVMForScenarioValidateExitCode(ctx, s, strings.Join(command, "\n"), 0, "MIG instances with profile "+migProfile+" were not found")

stdout := execResult.stdout.String()
require.Contains(s.T, stdout, migProfile, "expected to find MIG profile %s in output, but did not.\nOutput:\n%s", migProfile, stdout)
require.NotContains(s.T, stdout, "No MIG-enabled devices found", "no MIG devices were created.\nOutput:\n%s", stdout)
s.T.Logf("MIG instances with profile %s are created", migProfile)
}
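One note on the metric choice: as far as I know, DCGM does not report DCGM_FI_DEV_GPU_UTIL for MIG-enabled GPUs, which is presumably why the MIG scenario scrapes DCGM_FI_DEV_GPU_TEMP instead. A hypothetical helper (not part of the PR) could make that choice explicit for future scenarios:

// Hypothetical helper, not in the PR: pick a DCGM metric the exporter is expected
// to publish for the scenario. GPU utilization is assumed to be unavailable under
// MIG, so a temperature metric is checked there instead.
func commonDCGMMetric(migEnabled bool) string {
    if migEnabled {
        return "DCGM_FI_DEV_GPU_TEMP"
    }
    return "DCGM_FI_DEV_GPU_UTIL"
}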
2 changes: 1 addition & 1 deletion parts/linux/cloud-init/artifacts/cse_config.sh
@@ -1150,7 +1150,7 @@ startNvidiaManagedExpServices() {
[Service]
Environment="MIG_STRATEGY=--mig-strategy single"
ExecStart=
ExecStart=/usr/local/bin/nvidia-device-plugin $MIG_STRATEGY
ExecStart=/usr/bin/nvidia-device-plugin $MIG_STRATEGY
EOF
# Reload systemd to pick up the base path override
systemctl daemon-reload