Commit fd46cb8

nvidia-device-plugin: Fix binary path when using MIG (#7201)
Signed-off-by: Suraj Deshmukh <[email protected]>
1 parent 622536f commit fd46cb8
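
In short: the systemd drop-in written for the NVIDIA device plugin in the managed-GPU-experience path pointed ExecStart at /usr/local/bin/nvidia-device-plugin; this commit repoints it to /usr/bin/nvidia-device-plugin (see parts/linux/cloud-init/artifacts/cse_config.sh below). Alongside the one-line path fix, it adds e2e coverage for MIG: a MIG-enabled Ubuntu 24.04 scenario, new MIG validators, parameterized GPU-count and DCGM-metric checks, and a retrying SSH connectivity check for scenarios where the node reboots during provisioning.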

File tree: 103 files changed (+301 additions, −120 deletions)

e2e/scenario_gpu_managed_experience_test.go

Lines changed: 73 additions & 9 deletions
@@ -3,6 +3,7 @@ package e2e
 import (
     "context"
     "testing"
+    "time"

     "github.com/Azure/agentbaker/e2e/components"
     "github.com/Azure/agentbaker/e2e/config"
@@ -63,10 +64,10 @@ func Test_Ubuntu2404_NvidiaDevicePluginRunning(t *testing.T) {
                 ValidateNvidiaDevicePluginServiceRunning(ctx, s)

                 // Validate that GPU resources are advertised by the device plugin
-                ValidateNodeAdvertisesGPUResources(ctx, s)
+                ValidateNodeAdvertisesGPUResources(ctx, s, 1)

                 // Validate that GPU workloads can be scheduled
-                ValidateGPUWorkloadSchedulable(ctx, s)
+                ValidateGPUWorkloadSchedulable(ctx, s, 1)

                 // Validate that the NVIDIA DCGM packages were installed correctly
                 for _, packageName := range getDCGMPackageNames(os) {
@@ -77,7 +78,7 @@ func Test_Ubuntu2404_NvidiaDevicePluginRunning(t *testing.T) {

                 ValidateNvidiaDCGMExporterSystemDServiceRunning(ctx, s)
                 ValidateNvidiaDCGMExporterIsScrapable(ctx, s)
-                ValidateNvidiaDCGMExporterScrapeCommonMetric(ctx, s)
+                ValidateNvidiaDCGMExporterScrapeCommonMetric(ctx, s, "DCGM_FI_DEV_GPU_UTIL")
             },
         },
     })
@@ -118,10 +119,10 @@ func Test_Ubuntu2204_NvidiaDevicePluginRunning(t *testing.T) {
                 ValidateNvidiaDevicePluginServiceRunning(ctx, s)

                 // Validate that GPU resources are advertised by the device plugin
-                ValidateNodeAdvertisesGPUResources(ctx, s)
+                ValidateNodeAdvertisesGPUResources(ctx, s, 1)

                 // Validate that GPU workloads can be scheduled
-                ValidateGPUWorkloadSchedulable(ctx, s)
+                ValidateGPUWorkloadSchedulable(ctx, s, 1)

                 for _, packageName := range getDCGMPackageNames(os) {
                     versions := components.GetExpectedPackageVersions(packageName, os, osVersion)
@@ -131,7 +132,7 @@ func Test_Ubuntu2204_NvidiaDevicePluginRunning(t *testing.T) {

                 ValidateNvidiaDCGMExporterSystemDServiceRunning(ctx, s)
                 ValidateNvidiaDCGMExporterIsScrapable(ctx, s)
-                ValidateNvidiaDCGMExporterScrapeCommonMetric(ctx, s)
+                ValidateNvidiaDCGMExporterScrapeCommonMetric(ctx, s, "DCGM_FI_DEV_GPU_UTIL")
             },
         },
     })
@@ -172,10 +173,10 @@ func Test_AzureLinux3_NvidiaDevicePluginRunning(t *testing.T) {
                 ValidateNvidiaDevicePluginServiceRunning(ctx, s)

                 // Validate that GPU resources are advertised by the device plugin
-                ValidateNodeAdvertisesGPUResources(ctx, s)
+                ValidateNodeAdvertisesGPUResources(ctx, s, 1)

                 // Validate that GPU workloads can be scheduled
-                ValidateGPUWorkloadSchedulable(ctx, s)
+                ValidateGPUWorkloadSchedulable(ctx, s, 1)

                 for _, packageName := range getDCGMPackageNames(os) {
                     versions := components.GetExpectedPackageVersions(packageName, os, osVersion)
@@ -185,7 +186,70 @@ func Test_AzureLinux3_NvidiaDevicePluginRunning(t *testing.T) {

                 ValidateNvidiaDCGMExporterSystemDServiceRunning(ctx, s)
                 ValidateNvidiaDCGMExporterIsScrapable(ctx, s)
-                ValidateNvidiaDCGMExporterScrapeCommonMetric(ctx, s)
+                ValidateNvidiaDCGMExporterScrapeCommonMetric(ctx, s, "DCGM_FI_DEV_GPU_UTIL")
+            },
+        },
+    })
+}
+
+func Test_Ubuntu2404_NvidiaDevicePluginRunning_MIG(t *testing.T) {
+    RunScenario(t, &Scenario{
+        Description: "Tests that NVIDIA device plugin and DCGM Exporter work with MIG enabled on Ubuntu 24.04 GPU nodes",
+        Tags: Tags{
+            GPU: true,
+        },
+        Config: Config{
+            Cluster: ClusterKubenet,
+            VHD: config.VHDUbuntu2404Gen2Containerd,
+            WaitForSSHAfterReboot: 5 * time.Minute,
+            BootstrapConfigMutator: func(nbc *datamodel.NodeBootstrappingConfiguration) {
+                nbc.AgentPoolProfile.VMSize = "Standard_NC24ads_A100_v4"
+                nbc.ConfigGPUDriverIfNeeded = true
+                nbc.EnableGPUDevicePluginIfNeeded = true
+                nbc.EnableNvidia = true
+                nbc.GPUInstanceProfile = "MIG2g"
+            },
+            VMConfigMutator: func(vmss *armcompute.VirtualMachineScaleSet) {
+                vmss.SKU.Name = to.Ptr("Standard_NC24ads_A100_v4")
+                if vmss.Tags == nil {
+                    vmss.Tags = map[string]*string{}
+                }
+                vmss.Tags["EnableManagedGPUExperience"] = to.Ptr("true")
+            },
+            Validator: func(ctx context.Context, s *Scenario) {
+                os := "ubuntu"
+                osVersion := "r2404"
+
+                // Validate that the NVIDIA device plugin binary was installed correctly
+                versions := components.GetExpectedPackageVersions("nvidia-device-plugin", os, osVersion)
+                require.Lenf(s.T, versions, 1, "Expected exactly one nvidia-device-plugin version for %s %s but got %d", os, osVersion, len(versions))
+                ValidateInstalledPackageVersion(ctx, s, "nvidia-device-plugin", versions[0])
+
+                // Validate that the NVIDIA device plugin systemd service is running
+                ValidateNvidiaDevicePluginServiceRunning(ctx, s)
+
+                // Validate that MIG mode is enabled via nvidia-smi
+                ValidateMIGModeEnabled(ctx, s)
+
+                // Validate that MIG instances are created
+                ValidateMIGInstancesCreated(ctx, s, "MIG 2g.20gb")
+
+                // Validate that GPU resources are advertised by the device plugin
+                ValidateNodeAdvertisesGPUResources(ctx, s, 3)
+
+                // Validate that MIG workloads can be scheduled
+                ValidateGPUWorkloadSchedulable(ctx, s, 3)
+
+                // Validate that the NVIDIA DCGM packages were installed correctly
+                for _, packageName := range getDCGMPackageNames(os) {
+                    versions := components.GetExpectedPackageVersions(packageName, os, osVersion)
+                    require.Lenf(s.T, versions, 1, "Expected exactly one %s version for %s %s but got %d", packageName, os, osVersion, len(versions))
+                    ValidateInstalledPackageVersion(ctx, s, packageName, versions[0])
+                }
+
+                ValidateNvidiaDCGMExporterSystemDServiceRunning(ctx, s)
+                ValidateNvidiaDCGMExporterIsScrapable(ctx, s)
+                ValidateNvidiaDCGMExporterScrapeCommonMetric(ctx, s, "DCGM_FI_DEV_GPU_TEMP")
             },
         },
     })
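
A note on the expected counts and metric in the MIG scenario: Standard_NC24ads_A100_v4 carries a single A100 80GB, and the MIG2g profile ("MIG 2g.20gb") uses two of the GPU's seven compute slices, so at most three such instances fit on the card (3 x 2 = 6 of 7 slices). That is why the test expects nvidia.com/gpu=3 and schedules a workload requesting 3 devices under the single MIG strategy. The DCGM check also switches from DCGM_FI_DEV_GPU_UTIL to DCGM_FI_DEV_GPU_TEMP, since the device-level utilization metric is generally not reported once MIG is enabled, while temperature still is. (The slice math and metric behaviour are per NVIDIA's MIG and DCGM documentation, not stated in this commit.)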

e2e/test_helpers.go

Lines changed: 75 additions & 0 deletions
@@ -22,6 +22,7 @@ import (
     "github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
     "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6"
     "github.com/stretchr/testify/require"
+    "k8s.io/apimachinery/pkg/util/wait"
     ctrruntimelog "sigs.k8s.io/controller-runtime/pkg/log"
     "sigs.k8s.io/controller-runtime/pkg/log/zap"
 )
@@ -649,7 +650,81 @@ func CreateSIGImageVersionFromDisk(ctx context.Context, s *Scenario, version str
     return &customVHD
 }

+// isRebootRelatedSSHError checks if the error is related to a system reboot
+func isRebootRelatedSSHError(err error, stderr string) bool {
+    if err == nil {
+        return false
+    }
+
+    rebootIndicators := []string{
+        "System is going down",
+        "pam_nologin",
+        "Connection closed by",
+        "Connection refused",
+        "Connection timed out",
+    }
+
+    errMsg := err.Error()
+    for _, indicator := range rebootIndicators {
+        if strings.Contains(errMsg, indicator) || strings.Contains(stderr, indicator) {
+            return true
+        }
+    }
+    return false
+}
+
 func validateSSHConnectivity(ctx context.Context, s *Scenario) error {
+    // If WaitForSSHAfterReboot is not set, use the original single-attempt behavior
+    if s.Config.WaitForSSHAfterReboot == 0 {
+        return attemptSSHConnection(ctx, s)
+    }
+
+    // Retry logic with exponential backoff for scenarios that may reboot
+    s.T.Logf("SSH connectivity validation will retry for up to %s if reboot-related errors are encountered", s.Config.WaitForSSHAfterReboot)
+    startTime := time.Now()
+    var lastSSHError error
+
+    err := wait.PollUntilContextTimeout(ctx, 5*time.Second, s.Config.WaitForSSHAfterReboot, true, func(ctx context.Context) (bool, error) {
+        err := attemptSSHConnection(ctx, s)
+        if err == nil {
+            elapsed := time.Since(startTime)
+            s.T.Logf("SSH connectivity established after %s", toolkit.FormatDuration(elapsed))
+            return true, nil
+        }
+
+        // Save the last error for better error messages
+        lastSSHError = err
+
+        // Extract stderr from the error
+        stderr := ""
+        if strings.Contains(err.Error(), "Stderr:") {
+            parts := strings.Split(err.Error(), "Stderr:")
+            if len(parts) > 1 {
+                stderr = parts[1]
+            }
+        }
+
+        // Check if this is a reboot-related error
+        if isRebootRelatedSSHError(err, stderr) {
+            s.T.Logf("Detected reboot-related SSH error, will retry: %v", err)
+            return false, nil // Continue polling
+        }
+
+        // Not a reboot error, fail immediately
+        return false, err
+    })
+
+    // If we timed out while retrying reboot-related errors, provide a better error message
+    if err != nil && lastSSHError != nil {
+        elapsed := time.Since(startTime)
+        return fmt.Errorf("SSH connection failed after waiting %s for node to reboot and come back up. Last SSH error: %w", toolkit.FormatDuration(elapsed), lastSSHError)
+    }
+
+    return err
+}
+
+// attemptSSHConnection performs a single SSH connectivity check
+func attemptSSHConnection(ctx context.Context, s *Scenario) error {
     connectionTest := fmt.Sprintf("%s echo 'SSH_CONNECTION_OK'", sshString(s.Runtime.VMPrivateIP))
     connectionResult, err := execOnPrivilegedPod(ctx, s.Runtime.Cluster.Kube, defaultNamespace, s.Runtime.Cluster.DebugPod.Name, connectionTest)
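
Note that despite the in-code comment, wait.PollUntilContextTimeout polls at the fixed 5-second interval passed to it rather than with exponential backoff. The reboot-detection helper itself is pure string matching, so it is straightforward to unit-test; a minimal sketch (not part of this commit; the test name and error strings are hypothetical, and it assumes the package imports errors, testing, and require):

func TestIsRebootRelatedSSHError(t *testing.T) {
    cases := []struct {
        name   string
        err    error
        stderr string
        want   bool
    }{
        // nil errors are never treated as reboot-related
        {name: "nil error", err: nil, stderr: "Connection refused", want: false},
        // indicator found in the error message itself
        {name: "connection closed during reboot", err: errors.New("ssh: Connection closed by 10.224.0.5"), want: true},
        // indicator only present in captured stderr
        {name: "nologin banner in stderr", err: errors.New("exit status 255"), stderr: "System is going down for reboot", want: true},
        // unrelated SSH failures should fail fast, not retry
        {name: "auth failure", err: errors.New("Permission denied (publickey)"), want: false},
    }
    for _, tc := range cases {
        t.Run(tc.name, func(t *testing.T) {
            require.Equal(t, tc.want, isRebootRelatedSSHError(tc.err, tc.stderr))
        })
    }
}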

e2e/types.go

Lines changed: 6 additions & 0 deletions
@@ -10,6 +10,7 @@ import (
     "strconv"
     "strings"
     "testing"
+    "time"

     aksnodeconfigv1 "github.com/Azure/agentbaker/aks-node-controller/pkg/gen/aksnodeconfig/v1"
     "github.com/Azure/agentbaker/e2e/config"
@@ -166,6 +167,11 @@
     // It shouldn't be used for majority of scenarios, currently only used for scenarios where the node is not expected to be reachable via ssh
     SkipSSHConnectivityValidation bool

+    // WaitForSSHAfterReboot if set to non-zero duration, SSH connectivity validation will retry with exponential backoff
+    // for up to this duration when encountering reboot-related errors. This is useful for scenarios where the node
+    // reboots during provisioning (e.g., MIG-enabled GPU nodes). Default (zero value) means no retry.
+    WaitForSSHAfterReboot time.Duration
+
     // if VHDCaching is set then a VHD will be created first for the test scenario and then a VM will be created from that VHD.
     // The main purpose is to validate VHD Caching logic and ensure a reboot step between basePrep and nodePrep doesn't break anything.
     VHDCaching bool
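
The new knob is opted into per scenario: the MIG test above sets WaitForSSHAfterReboot: 5 * time.Minute, which makes validateSSHConnectivity in e2e/test_helpers.go keep retrying the SSH probe while the node reboots during provisioning instead of failing on the first refused connection.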

e2e/validators.go

Lines changed: 48 additions & 12 deletions
@@ -968,29 +968,29 @@ func ValidateNvidiaDevicePluginServiceRunning(ctx context.Context, s *Scenario)
     execScriptOnVMForScenarioValidateExitCode(ctx, s, strings.Join(command, "\n"), 0, "NVIDIA device plugin systemd service should be active and enabled")
 }

-func ValidateNodeAdvertisesGPUResources(ctx context.Context, s *Scenario) {
+func ValidateNodeAdvertisesGPUResources(ctx context.Context, s *Scenario, gpuCountExpected int64) {
     s.T.Helper()
     s.T.Logf("validating that node advertises GPU resources")
+    resourceName := "nvidia.com/gpu"

     // First, wait for the nvidia.com/gpu resource to be available
-    waitUntilResourceAvailable(ctx, s, "nvidia.com/gpu")
+    waitUntilResourceAvailable(ctx, s, resourceName)

     // Get the node using the Kubernetes client from the test framework
     nodeName := s.Runtime.KubeNodeName
     node, err := s.Runtime.Cluster.Kube.Typed.CoreV1().Nodes().Get(ctx, nodeName, metav1.GetOptions{})
     require.NoError(s.T, err, "failed to get node %q", nodeName)

     // Check if the node advertises GPU capacity
-    gpuCapacity, exists := node.Status.Capacity["nvidia.com/gpu"]
-    require.True(s.T, exists, "node should advertise nvidia.com/gpu capacity")
+    gpuCapacity, exists := node.Status.Capacity[corev1.ResourceName(resourceName)]
+    require.True(s.T, exists, "node should advertise resource %s", resourceName)

     gpuCount := gpuCapacity.Value()
-    require.Greater(s.T, gpuCount, int64(0), "node should advertise at least 1 GPU, but got %d", gpuCount)
-
-    s.T.Logf("node %s advertises %d nvidia.com/gpu resources", nodeName, gpuCount)
+    require.Equal(s.T, gpuCount, gpuCountExpected, "node should advertise %s=%d, but got %s=%d", resourceName, gpuCountExpected, resourceName, gpuCount)
+    s.T.Logf("node %s advertises %s=%d resources", nodeName, resourceName, gpuCount)
 }

-func ValidateGPUWorkloadSchedulable(ctx context.Context, s *Scenario) {
+func ValidateGPUWorkloadSchedulable(ctx context.Context, s *Scenario, gpuCount int) {
     s.T.Helper()
     s.T.Logf("validating that GPU workloads can be scheduled")

@@ -1014,7 +1014,7 @@ func ValidateGPUWorkloadSchedulable(ctx context.Context, s *Scenario) {
                 },
                 Resources: corev1.ResourceRequirements{
                     Limits: corev1.ResourceList{
-                        "nvidia.com/gpu": resource.MustParse("1"),
+                        "nvidia.com/gpu": resource.MustParse(fmt.Sprintf("%d", gpuCount)),
                     },
                 },
             },
@@ -1156,12 +1156,48 @@ func ValidateNvidiaDCGMExporterIsScrapable(ctx context.Context, s *Scenario) {
     execScriptOnVMForScenarioValidateExitCode(ctx, s, strings.Join(command, "\n"), 0, "Nvidia DCGM Exporter is not scrapable on port 19400")
 }

-func ValidateNvidiaDCGMExporterScrapeCommonMetric(ctx context.Context, s *Scenario) {
+func ValidateNvidiaDCGMExporterScrapeCommonMetric(ctx context.Context, s *Scenario, metric string) {
     s.T.Helper()
     command := []string{
         "set -ex",
         // Verify the most universal GPU metric is present
-        "curl -s http://localhost:19400/metrics | grep -q 'DCGM_FI_DEV_GPU_UTIL'",
+        "curl -s http://localhost:19400/metrics | grep -q '" + metric + "'",
+    }
+    execScriptOnVMForScenarioValidateExitCode(ctx, s, strings.Join(command, "\n"), 0, "Nvidia DCGM Exporter is not returning "+metric)
+}
+
+func ValidateMIGModeEnabled(ctx context.Context, s *Scenario) {
+    s.T.Helper()
+    s.T.Logf("validating that MIG mode is enabled")
+
+    command := []string{
+        "set -ex",
+        // Grep to verify it contains 'Enabled' - this will fail if MIG is disabled
+        "sudo nvidia-smi --query-gpu=mig.mode.current --format=csv,noheader | grep -i 'Enabled'",
     }
-    execScriptOnVMForScenarioValidateExitCode(ctx, s, strings.Join(command, "\n"), 0, "Nvidia DCGM Exporter is not returning DCGM_FI_DEV_GPU_UTIL")
+    execResult := execScriptOnVMForScenarioValidateExitCode(ctx, s, strings.Join(command, "\n"), 0, "MIG mode is not enabled")
+
+    stdout := strings.TrimSpace(execResult.stdout.String())
+    s.T.Logf("MIG mode status: %s", stdout)
+    require.Contains(s.T, stdout, "Enabled", "expected MIG mode to be enabled, but got: %s", stdout)
+    s.T.Logf("MIG mode is enabled")
+}
+
+func ValidateMIGInstancesCreated(ctx context.Context, s *Scenario, migProfile string) {
+    s.T.Helper()
+    s.T.Logf("validating that MIG instances are created with profile %s", migProfile)
+
+    command := []string{
+        "set -ex",
+        // List MIG devices using nvidia-smi
+        "sudo nvidia-smi mig -lgi",
+        // Ensure the output contains the expected MIG profile (will fail if "No MIG-enabled devices found")
+        "sudo nvidia-smi mig -lgi | grep -v 'No MIG-enabled devices found' | grep -q '" + migProfile + "'",
+    }
+    execResult := execScriptOnVMForScenarioValidateExitCode(ctx, s, strings.Join(command, "\n"), 0, "MIG instances with profile "+migProfile+" were not found")
+
+    stdout := execResult.stdout.String()
+    require.Contains(s.T, stdout, migProfile, "expected to find MIG profile %s in output, but did not.\nOutput:\n%s", migProfile, stdout)
+    require.NotContains(s.T, stdout, "No MIG-enabled devices found", "no MIG devices were created.\nOutput:\n%s", stdout)
+    s.T.Logf("MIG instances with profile %s are created", migProfile)
 }
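
Both MIG validators shell out to nvidia-smi through execScriptOnVMForScenarioValidateExitCode, the same pattern the existing validators use. If deeper coverage were wanted, a companion check for compute instances could follow the same shape; a minimal sketch (not part of this commit, the helper name is hypothetical, and it assumes `nvidia-smi mig -lci` lists the profile name in its output):

func ValidateMIGComputeInstancesCreated(ctx context.Context, s *Scenario, migProfile string) {
    s.T.Helper()
    s.T.Logf("validating that MIG compute instances exist for profile %s", migProfile)

    command := []string{
        "set -ex",
        // List MIG compute instances; the grep fails the script if the expected profile is absent
        "sudo nvidia-smi mig -lci",
        "sudo nvidia-smi mig -lci | grep -q '" + migProfile + "'",
    }
    execScriptOnVMForScenarioValidateExitCode(ctx, s, strings.Join(command, "\n"), 0, "MIG compute instances for profile "+migProfile+" were not found")
}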

parts/linux/cloud-init/artifacts/cse_config.sh

Lines changed: 1 addition & 1 deletion
@@ -1150,7 +1150,7 @@ startNvidiaManagedExpServices() {
 [Service]
 Environment="MIG_STRATEGY=--mig-strategy single"
 ExecStart=
-ExecStart=/usr/local/bin/nvidia-device-plugin $MIG_STRATEGY
+ExecStart=/usr/bin/nvidia-device-plugin $MIG_STRATEGY
 EOF
     # Reload systemd to pick up the base path override
     systemctl daemon-reload
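
For context on the drop-in syntax: for a normal (non-oneshot) service, a systemd override must first clear the ExecStart inherited from the base unit with a bare ExecStart= line before setting a new one, which is why the heredoc writes ExecStart= twice. The only functional change here is the binary path, now /usr/bin/nvidia-device-plugin; the MIG_STRATEGY environment variable and the daemon-reload that follows are unchanged.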
