Skip to content

Commit 8db1ba7

Browse files
committed
wip: handle driver migrations in nvidiadriver controller
Signed-off-by: Christopher Desiniotis <[email protected]>
1 parent 81640fd commit 8db1ba7

File tree

1 file changed

+111
-0
lines changed

1 file changed

+111
-0
lines changed

controllers/nvidiadriver_controller.go

+111
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
corev1 "k8s.io/api/core/v1"
2727
apierrors "k8s.io/apimachinery/pkg/api/errors"
2828
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
29+
"k8s.io/apimachinery/pkg/labels"
2930
"k8s.io/apimachinery/pkg/runtime"
3031
"k8s.io/apimachinery/pkg/types"
3132
"k8s.io/client-go/util/workqueue"
@@ -168,6 +169,10 @@ func (r *NVIDIADriverReconciler) Reconcile(ctx context.Context, req ctrl.Request
168169
return reconcile.Result{}, nil
169170
}
170171

172+
err = updateNodesManagedByDriver(ctx, r, instance)
173+
if err != nil {
174+
return reconcile.Result{}, fmt.Errorf("failed to update nodes managed by driver: %w", err)
175+
}
171176
// Sync state and update status
172177
managerStatus := r.stateManager.SyncState(ctx, instance, infoCatalog)
173178

@@ -404,5 +409,111 @@ func (r *NVIDIADriverReconciler) SetupWithManager(ctx context.Context, mgr ctrl.
404409
return fmt.Errorf("failed to add index key: %w", err)
405410
}
406411

412+
if err := mgr.GetFieldIndexer().IndexField(ctx, &corev1.Pod{}, "spec.nodeName", func(rawObj client.Object) []string {
413+
pod := rawObj.(*corev1.Pod)
414+
return []string{pod.Spec.NodeName}
415+
}); err != nil {
416+
return err
417+
}
418+
419+
return nil
420+
}
421+
422+
func updateNodesManagedByDriver(ctx context.Context, r NVIDIADriverReconciler, instance *nvidiav1alpha1.NVIDIADriver) error {
423+
nodes, err := getNVIDIADriverSelectedNodes(ctx, r.Client, instance)
424+
if err != nil {
425+
return fmt.Errorf("failed to get selected nodes for NVIDIADriver CR: %w", err)
426+
}
427+
428+
// A map tracking which node objects need to be updated. E.g. updated label / annotations
429+
// need to be applied.
430+
nodesToUpdate := map[*corev1.Node]struct{}{}
431+
432+
for _, node := range nodes.Items {
433+
labels := node.GetLabels()
434+
annotations := node.GetAnnotations()
435+
436+
managedBy, exists := labels["nvidia.com/gpu.driver.managed-by"]
437+
if !exists {
438+
// if 'managed-by' label does not exist, label node with cr.Name
439+
labels["nvidia.com/gpu.driver.managed-by"] = instance.Name
440+
nodesToUpdate[&node] = struct{}{}
441+
// if there is an orphan driver pod running on the node,
442+
// indicate to the upgrade controller that an upgrade is required
443+
podList := &corev1.PodList{}
444+
err = r.Client.List(ctx, podList,
445+
client.InNamespace("gpu-operator"),
446+
client.MatchingLabels(map[string]string{DriverLabelKey: DriverLabelValue}),
447+
client.MatchingFields{"spec.nodeName": node.Name})
448+
if err != nil {
449+
return fmt.Errorf("failed to list driver pods: %w", err)
450+
}
451+
if len(podList.Items) == 0 {
452+
continue
453+
}
454+
if len(podList.Items) != 1 {
455+
return fmt.Errorf("there are multiple driver pods running on node %s", node.Name)
456+
}
457+
pod := podList.Items[0]
458+
if pod.OwnerReferences == nil || len(pod.OwnerReferences) == 0 {
459+
annotations["nvidia.com/gpu-driver-upgrade-requested"] = "true"
460+
}
461+
continue
462+
}
463+
464+
// do nothing if node is already being managed by this CR
465+
if managedBy == instance.Name {
466+
continue
467+
}
468+
469+
}
470+
471+
// Apply updated labels / annotations on node objects
472+
for node := range nodesToUpdate {
473+
err = r.Client.Update(ctx, node)
474+
if err != nil {
475+
return fmt.Errorf("failed to update node %s: %w", node.Name, err)
476+
}
477+
}
478+
407479
return nil
408480
}
481+
482+
// getNVIDIADriverSelectedNodes returns selected nodes based on the nodeselector labels set for a given NVIDIADriver instance
483+
func getNVIDIADriverSelectedNodes(ctx context.Context, k8sClient client.Client, cr *nvidiav1alpha1.NVIDIADriver) (*corev1.NodeList, error) {
484+
nodeList := &corev1.NodeList{}
485+
486+
if cr.Spec.NodeSelector == nil {
487+
cr.Spec.NodeSelector = cr.GetNodeSelector()
488+
}
489+
490+
selector := labels.Set(cr.Spec.NodeSelector).AsSelector()
491+
492+
opts := []client.ListOption{
493+
client.MatchingLabelsSelector{Selector: selector},
494+
}
495+
err := k8sClient.List(ctx, nodeList, opts...)
496+
497+
return nodeList, err
498+
}
499+
500+
/*
501+
func getDriverPodLabelSelector(clusterPolicy gpuv1.ClusterPolicy) map[string]string {
502+
// initialize with common app=nvidia-driver-daemonset label
503+
driverLabelKey := DriverLabelKey
504+
driverLabelValue := DriverLabelValue
505+
506+
if clusterPolicy.Spec.Driver.UseNvdiaDriverCRDType() {
507+
// app component label is added for all new driver daemonsets deployed by NVIDIADriver controller
508+
driverLabelKey = AppComponentLabelKey
509+
driverLabelValue = AppComponentLabelValue
510+
} else if clusterPolicyCtrl.openshift != "" && clusterPolicyCtrl.ocpDriverToolkit.enabled {
511+
// For OCP, when DTK is enabled app=nvidia-driver-daemonset label is not constant and changes
512+
// based on rhcos version. Hence use DTK label instead
513+
driverLabelKey = ocpDriverToolkitIdentificationLabel
514+
driverLabelValue = ocpDriverToolkitIdentificationValue
515+
}
516+
517+
return map[string]string{driverLabelKey: driverLabelValue}
518+
}
519+
*/

0 commit comments

Comments
 (0)