From 7563bcaefee03a7f125dc9f8089939f1b41de84c Mon Sep 17 00:00:00 2001 From: Bryce Soghigian <49734722+Bryce-Soghigian@users.noreply.github.com> Date: Wed, 29 Jan 2025 17:45:43 -0800 Subject: [PATCH] fix: properly garbage collecting orphaned network interfaces (#642) * feat: adding ListNics to instanceprovider interface alongside a refactor of arg related functions to their own file * feat: adding garbage collection logic for network interfaces and refactoring gc functions slightly * feat: working poc of ARG Queries and nic garbage collection, need to fix tests * fix: tests failing due to ListVM * feat: split nic and vm into their own gc controllers, added shared state between them to prevent conflicts in nic deletion calls * feat: Add DeleteNic option to instance provider * docs(code-comment): adding clarification to unremovableNics * test: instanceprovider.ListNics barebones test and arg fake list nic impl * fix: bootstrap.sh * feat: adding in VM List into the networkinterface.garbagecollection controller to avoid attempts to delete nics that are managed by a vm * refactor: unremovableNics out of the vm gc controller in favor for a cleaner state list * fix:updating references to cache * Update pkg/controllers/nodeclaim/garbagecollection/nic_garbagecollection.go * test: adding composable network interface options to our test utils based on existing karp-core pattern * test: adding test we don't include unmanaged nics * test: adding network garbage collection suite and happy path * test: adding tests for unremovable nics * test: adding coverage that vm controller cleans up nics * refactor: renaming controller * fix: refactor name * refactor: using import directly * ci: checking error for controller * fix: ci * fix: addressing comments * Update pkg/controllers/nodeclaim/garbagecollection/instance_garbagecollection.go * refactor: removing name constant * refactor: moving to test utils * fix: removing GetZoneID * style: improving the readability of the network interface garbage collection tests * revert: removing lo.FromPtr checks for nodeclaim creation to avoid creating nodeclaims with empty required properties * refactor: VirtualmachineProperties needs a default for time created * fix: modifying the tests to be aware of time created to properly simulate vm gc * refactor: added nodepoolName to test.VirtualMachine and test.Interface * refactor: moving vm gc tests to use test.VirtualMachine * refactor: renaming arg files to azureresourcegraph * refactor: using deleteIfNicExists directly in cleanupAzureResources * test: createNicFromQueryResponseData missing id, missing name and happy case * refactor: using fake.Region for the default region for test.VirtualMachine and test.NetworkInterface * fix: using input.InterfaceName rather than the outer scope interface name * test: modifying fake to only initialize the query once --- pkg/cloudprovider/cloudprovider.go | 2 +- pkg/cloudprovider/suite_test.go | 2 +- pkg/controllers/controllers.go | 5 +- ...oller.go => instance_garbagecollection.go} | 16 +- .../nic_garbagecollection.go | 112 +++++++++++++ .../nodeclaim/garbagecollection/suite_test.go | 152 ++++++++++++------ pkg/fake/azureresourcegraphapi.go | 39 ++++- pkg/fake/azureresourcegraphapi_test.go | 4 +- pkg/fake/networkinterfaceapi.go | 3 +- .../instance/azureresourcegraphlist.go | 108 +++++++++++++ ...argutils.go => azureresourcegraphutils.go} | 0 ...mutils.go => azureresourcemanagerutils.go} | 0 pkg/providers/instance/instance.go | 92 +++++------ pkg/providers/instance/instance_test.go | 62 +++++++ pkg/providers/instance/suite_test.go | 25 ++- pkg/test/environment.go | 5 +- pkg/test/expectations/expectations.go | 8 +- pkg/test/networkinterfaces.go | 83 ++++++++++ pkg/test/utils.go | 35 ++++ pkg/test/virtualmachines.go | 78 +++++++++ 20 files changed, 712 insertions(+), 119 deletions(-) rename pkg/controllers/nodeclaim/garbagecollection/{controller.go => instance_garbagecollection.go} (87%) create mode 100644 pkg/controllers/nodeclaim/garbagecollection/nic_garbagecollection.go create mode 100644 pkg/providers/instance/azureresourcegraphlist.go rename pkg/providers/instance/{argutils.go => azureresourcegraphutils.go} (100%) rename pkg/providers/instance/{armutils.go => azureresourcemanagerutils.go} (100%) create mode 100644 pkg/test/networkinterfaces.go create mode 100644 pkg/test/utils.go create mode 100644 pkg/test/virtualmachines.go diff --git a/pkg/cloudprovider/cloudprovider.go b/pkg/cloudprovider/cloudprovider.go index 6ffaee0fb..fa5992bd8 100644 --- a/pkg/cloudprovider/cloudprovider.go +++ b/pkg/cloudprovider/cloudprovider.go @@ -138,6 +138,7 @@ func (c *CloudProvider) List(ctx context.Context) ([]*karpv1.NodeClaim, error) { if err != nil { return nil, fmt.Errorf("listing instances, %w", err) } + var nodeClaims []*karpv1.NodeClaim for _, instance := range instances { instanceType, err := c.resolveInstanceTypeFromInstance(ctx, instance) @@ -337,7 +338,6 @@ func (c *CloudProvider) instanceToNodeClaim(ctx context.Context, vm *armcompute. labels[karpv1.CapacityTypeLabelKey] = instance.GetCapacityType(vm) - // TODO: v1beta1 new kes/labels if tag, ok := vm.Tags[instance.NodePoolTagKey]; ok { labels[karpv1.NodePoolLabelKey] = *tag } diff --git a/pkg/cloudprovider/suite_test.go b/pkg/cloudprovider/suite_test.go index 0a9c86c0f..146502887 100644 --- a/pkg/cloudprovider/suite_test.go +++ b/pkg/cloudprovider/suite_test.go @@ -144,7 +144,7 @@ var _ = Describe("CloudProvider", func() { nodeClaims, _ := cloudProvider.List(ctx) Expect(azureEnv.AzureResourceGraphAPI.AzureResourceGraphResourcesBehavior.CalledWithInput.Len()).To(Equal(1)) queryRequest := azureEnv.AzureResourceGraphAPI.AzureResourceGraphResourcesBehavior.CalledWithInput.Pop().Query - Expect(*queryRequest.Query).To(Equal(instance.GetListQueryBuilder(azureEnv.AzureResourceGraphAPI.ResourceGroup).String())) + Expect(*queryRequest.Query).To(Equal(instance.GetVMListQueryBuilder(azureEnv.AzureResourceGraphAPI.ResourceGroup).String())) Expect(nodeClaims).To(HaveLen(1)) Expect(nodeClaims[0]).ToNot(BeNil()) resp, _ := azureEnv.VirtualMachinesAPI.Get(ctx, azureEnv.AzureResourceGraphAPI.ResourceGroup, nodeClaims[0].Name, nil) diff --git a/pkg/controllers/controllers.go b/pkg/controllers/controllers.go index e96be8104..793c1e93f 100644 --- a/pkg/controllers/controllers.go +++ b/pkg/controllers/controllers.go @@ -44,7 +44,10 @@ func NewControllers(ctx context.Context, mgr manager.Manager, kubeClient client. nodeclasshash.NewController(kubeClient), nodeclassstatus.NewController(kubeClient), nodeclasstermination.NewController(kubeClient, recorder), - nodeclaimgarbagecollection.NewController(kubeClient, cloudProvider), + + nodeclaimgarbagecollection.NewVirtualMachine(kubeClient, cloudProvider), + nodeclaimgarbagecollection.NewNetworkInterface(kubeClient, instanceProvider), + // TODO: nodeclaim tagging inplaceupdate.NewController(kubeClient, instanceProvider), status.NewController[*v1alpha2.AKSNodeClass](kubeClient, mgr.GetEventRecorderFor("karpenter")), diff --git a/pkg/controllers/nodeclaim/garbagecollection/controller.go b/pkg/controllers/nodeclaim/garbagecollection/instance_garbagecollection.go similarity index 87% rename from pkg/controllers/nodeclaim/garbagecollection/controller.go rename to pkg/controllers/nodeclaim/garbagecollection/instance_garbagecollection.go index 033dc31f3..f86fc9ada 100644 --- a/pkg/controllers/nodeclaim/garbagecollection/controller.go +++ b/pkg/controllers/nodeclaim/garbagecollection/instance_garbagecollection.go @@ -23,7 +23,6 @@ import ( "github.com/awslabs/operatorpkg/singleton" - // "github.com/Azure/karpenter-provider-azure/pkg/cloudprovider" "github.com/samber/lo" "go.uber.org/multierr" v1 "k8s.io/api/core/v1" @@ -41,21 +40,21 @@ import ( corecloudprovider "sigs.k8s.io/karpenter/pkg/cloudprovider" ) -type Controller struct { +type VirtualMachine struct { kubeClient client.Client cloudProvider corecloudprovider.CloudProvider - successfulCount uint64 // keeps track of successful reconciles for more aggressive requeueing near the start of the controller + successfulCount uint64 // keeps track of successful reconciles for more aggressive requeuing near the start of the controller } -func NewController(kubeClient client.Client, cloudProvider corecloudprovider.CloudProvider) *Controller { - return &Controller{ +func NewVirtualMachine(kubeClient client.Client, cloudProvider corecloudprovider.CloudProvider) *VirtualMachine { + return &VirtualMachine{ kubeClient: kubeClient, cloudProvider: cloudProvider, successfulCount: 0, } } -func (c *Controller) Reconcile(ctx context.Context) (reconcile.Result, error) { +func (c *VirtualMachine) Reconcile(ctx context.Context) (reconcile.Result, error) { ctx = injection.WithControllerName(ctx, "instance.garbagecollection") // We LIST VMs on the CloudProvider BEFORE we grab NodeClaims/Nodes on the cluster so that we make sure that, if @@ -65,6 +64,7 @@ func (c *Controller) Reconcile(ctx context.Context) (reconcile.Result, error) { if err != nil { return reconcile.Result{}, fmt.Errorf("listing cloudprovider VMs, %w", err) } + managedRetrieved := lo.Filter(retrieved, func(nc *karpv1.NodeClaim, _ int) bool { return nc.DeletionTimestamp.IsZero() }) @@ -93,7 +93,7 @@ func (c *Controller) Reconcile(ctx context.Context) (reconcile.Result, error) { return reconcile.Result{RequeueAfter: lo.Ternary(c.successfulCount <= 20, time.Second*10, time.Minute*2)}, nil } -func (c *Controller) garbageCollect(ctx context.Context, nodeClaim *karpv1.NodeClaim, nodeList *v1.NodeList) error { +func (c *VirtualMachine) garbageCollect(ctx context.Context, nodeClaim *karpv1.NodeClaim, nodeList *v1.NodeList) error { ctx = logging.WithLogger(ctx, logging.FromContext(ctx).With("provider-id", nodeClaim.Status.ProviderID)) if err := c.cloudProvider.Delete(ctx, nodeClaim); err != nil { return corecloudprovider.IgnoreNodeClaimNotFoundError(err) @@ -112,7 +112,7 @@ func (c *Controller) garbageCollect(ctx context.Context, nodeClaim *karpv1.NodeC return nil } -func (c *Controller) Register(_ context.Context, m manager.Manager) error { +func (c *VirtualMachine) Register(_ context.Context, m manager.Manager) error { return controllerruntime.NewControllerManagedBy(m). Named("instance.garbagecollection"). WatchesRawSource(singleton.Source()). diff --git a/pkg/controllers/nodeclaim/garbagecollection/nic_garbagecollection.go b/pkg/controllers/nodeclaim/garbagecollection/nic_garbagecollection.go new file mode 100644 index 000000000..7571e79d8 --- /dev/null +++ b/pkg/controllers/nodeclaim/garbagecollection/nic_garbagecollection.go @@ -0,0 +1,112 @@ +/* +Portions Copyright (c) Microsoft Corporation. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package garbagecollection + +import ( + "context" + "fmt" + "time" + + "github.com/samber/lo" + "knative.dev/pkg/logging" + + "github.com/awslabs/operatorpkg/singleton" + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/client-go/util/workqueue" + controllerruntime "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/manager" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + karpv1 "sigs.k8s.io/karpenter/pkg/apis/v1" + "sigs.k8s.io/karpenter/pkg/operator/injection" + + "github.com/Azure/karpenter-provider-azure/pkg/providers/instance" +) + +const ( + NicReservationDuration = time.Second * 180 + // We set this interval at 5 minutes, as thats how often our NRP limits are reset. + // See: https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/request-limits-and-throttling#network-throttling + NicGarbageCollectionInterval = time.Minute * 5 +) + +type NetworkInterface struct { + kubeClient client.Client + instanceProvider instance.Provider +} + +func NewNetworkInterface(kubeClient client.Client, instanceProvider instance.Provider) *NetworkInterface { + return &NetworkInterface{ + kubeClient: kubeClient, + instanceProvider: instanceProvider, + } +} + +func (c *NetworkInterface) populateUnremovableInterfaces(ctx context.Context) (sets.Set[string], error) { + unremovableInterfaces := sets.New[string]() + vms, err := c.instanceProvider.List(ctx) + if err != nil { + return unremovableInterfaces, fmt.Errorf("listing VMs: %w", err) + } + for _, vm := range vms { + unremovableInterfaces.Insert(lo.FromPtr(vm.Name)) + } + nodeClaimList := &karpv1.NodeClaimList{} + if err := c.kubeClient.List(ctx, nodeClaimList); err != nil { + return unremovableInterfaces, fmt.Errorf("listing NodeClaims for NIC GC: %w", err) + } + + for _, nodeClaim := range nodeClaimList.Items { + unremovableInterfaces.Insert(instance.GenerateResourceName(nodeClaim.Name)) + } + return unremovableInterfaces, nil +} + +func (c *NetworkInterface) Reconcile(ctx context.Context) (reconcile.Result, error) { + ctx = injection.WithControllerName(ctx, "networkinterface.garbagecollection") + nics, err := c.instanceProvider.ListNics(ctx) + if err != nil { + return reconcile.Result{}, fmt.Errorf("listing NICs: %w", err) + } + + unremovableInterfaces, err := c.populateUnremovableInterfaces(ctx) + if err != nil { + return reconcile.Result{}, fmt.Errorf("error listing resources needed to populate unremovable nics %w", err) + } + workqueue.ParallelizeUntil(ctx, 100, len(nics), func(i int) { + nicName := lo.FromPtr(nics[i].Name) + if !unremovableInterfaces.Has(nicName) { + err := c.instanceProvider.DeleteNic(ctx, nicName) + if err != nil { + logging.FromContext(ctx).Error(err) + return + } + + logging.FromContext(ctx).With("nic", nicName).Infof("garbage collected NIC") + } + }) + return reconcile.Result{ + RequeueAfter: NicGarbageCollectionInterval, + }, nil +} + +func (c *NetworkInterface) Register(_ context.Context, m manager.Manager) error { + return controllerruntime.NewControllerManagedBy(m). + Named("networkinterface.garbagecollection"). + WatchesRawSource(singleton.Source()). + Complete(singleton.AsReconciler(c)) +} diff --git a/pkg/controllers/nodeclaim/garbagecollection/suite_test.go b/pkg/controllers/nodeclaim/garbagecollection/suite_test.go index 5e9a7edee..3197e5c94 100644 --- a/pkg/controllers/nodeclaim/garbagecollection/suite_test.go +++ b/pkg/controllers/nodeclaim/garbagecollection/suite_test.go @@ -32,9 +32,9 @@ import ( "github.com/Azure/karpenter-provider-azure/pkg/apis/v1alpha2" "github.com/Azure/karpenter-provider-azure/pkg/cloudprovider" "github.com/Azure/karpenter-provider-azure/pkg/controllers/nodeclaim/garbagecollection" - "github.com/Azure/karpenter-provider-azure/pkg/fake" "github.com/Azure/karpenter-provider-azure/pkg/operator/options" "github.com/Azure/karpenter-provider-azure/pkg/providers/instance" + . "github.com/Azure/karpenter-provider-azure/pkg/test/expectations" "github.com/Azure/karpenter-provider-azure/pkg/utils" . "github.com/onsi/ginkgo/v2" @@ -64,7 +64,8 @@ var nodePool *karpv1.NodePool var nodeClass *v1alpha2.AKSNodeClass var cluster *state.Cluster var cloudProvider *cloudprovider.CloudProvider -var garbageCollectionController *garbagecollection.Controller +var virtualMachineGCController *garbagecollection.VirtualMachine +var networkInterfaceGCController *garbagecollection.NetworkInterface var prov *provisioning.Provisioner func TestAPIs(t *testing.T) { @@ -80,7 +81,8 @@ var _ = BeforeSuite(func() { // ctx, stop = context.WithCancel(ctx) azureEnv = test.NewEnvironment(ctx, env) cloudProvider = cloudprovider.New(azureEnv.InstanceTypesProvider, azureEnv.InstanceProvider, events.NewRecorder(&record.FakeRecorder{}), env.Client, azureEnv.ImageProvider) - garbageCollectionController = garbagecollection.NewController(env.Client, cloudProvider) + virtualMachineGCController = garbagecollection.NewVirtualMachine(env.Client, cloudProvider) + networkInterfaceGCController = garbagecollection.NewNetworkInterface(env.Client, azureEnv.InstanceProvider) fakeClock = &clock.FakeClock{} cluster = state.NewCluster(fakeClock, env.Client) prov = provisioning.NewProvisioner(env.Client, events.NewRecorder(&record.FakeRecorder{}), cloudProvider, cluster, fakeClock) @@ -119,7 +121,7 @@ var _ = AfterEach(func() { // TODO: move before/after each into the tests (see AWS) // review tests themselves (very different from AWS?) // (e.g. AWS has not a single ExpectPRovisioned? why?) -var _ = Describe("GarbageCollection", func() { +var _ = Describe("VirtualMachine Garbage Collection", func() { var vm *armcompute.VirtualMachine var providerID string var err error @@ -147,7 +149,7 @@ var _ = Describe("GarbageCollection", func() { }) azureEnv.VirtualMachinesAPI.Instances.Store(lo.FromPtr(vm.ID), *vm) - ExpectSingletonReconciled(ctx, garbageCollectionController) + ExpectSingletonReconciled(ctx, virtualMachineGCController) _, err := cloudProvider.Get(ctx, providerID) Expect(err).NotTo(HaveOccurred()) }) @@ -164,23 +166,18 @@ var _ = Describe("GarbageCollection", func() { vm, err = azureEnv.InstanceProvider.Get(ctx, vmName) Expect(err).To(BeNil()) providerID = utils.ResourceIDToProviderID(ctx, *vm.ID) - azureEnv.VirtualMachinesAPI.Instances.Store( - *vm.ID, - armcompute.VirtualMachine{ - ID: vm.ID, - Name: vm.Name, - Location: lo.ToPtr(fake.Region), - Properties: &armcompute.VirtualMachineProperties{ - TimeCreated: lo.ToPtr(time.Now().Add(-time.Minute * 10)), - }, - Tags: map[string]*string{ - instance.NodePoolTagKey: lo.ToPtr("default"), - }, - }) + newVM := test.VirtualMachine(test.VirtualMachineOptions{ + Name: vmName, + NodepoolName: "default", + Properties: &armcompute.VirtualMachineProperties{ + TimeCreated: lo.ToPtr(time.Now().Add(-time.Minute * 10)), + }, + }) + azureEnv.VirtualMachinesAPI.Instances.Store(lo.FromPtr(newVM.ID), newVM) ids = append(ids, *vm.ID) } } - ExpectSingletonReconciled(ctx, garbageCollectionController) + ExpectSingletonReconciled(ctx, virtualMachineGCController) wg := sync.WaitGroup{} for _, id := range ids { @@ -210,19 +207,14 @@ var _ = Describe("GarbageCollection", func() { vm, err = azureEnv.InstanceProvider.Get(ctx, vmName) Expect(err).To(BeNil()) providerID = utils.ResourceIDToProviderID(ctx, *vm.ID) - azureEnv.VirtualMachinesAPI.Instances.Store( - *vm.ID, - armcompute.VirtualMachine{ - ID: vm.ID, - Name: vm.Name, - Location: lo.ToPtr(fake.Region), - Properties: &armcompute.VirtualMachineProperties{ - TimeCreated: lo.ToPtr(time.Now().Add(-time.Minute * 10)), - }, - Tags: map[string]*string{ - instance.NodePoolTagKey: lo.ToPtr("default"), - }, - }) + newVM := test.VirtualMachine(test.VirtualMachineOptions{ + Name: vmName, + NodepoolName: "default", + Properties: &armcompute.VirtualMachineProperties{ + TimeCreated: lo.ToPtr(time.Now().Add(-time.Minute * 10)), + }, + }) + azureEnv.VirtualMachinesAPI.Instances.Store(lo.FromPtr(newVM.ID), newVM) nodeClaim := coretest.NodeClaim(karpv1.NodeClaim{ Status: karpv1.NodeClaimStatus{ ProviderID: utils.ResourceIDToProviderID(ctx, *vm.ID), @@ -233,7 +225,7 @@ var _ = Describe("GarbageCollection", func() { nodeClaims = append(nodeClaims, nodeClaim) } } - ExpectSingletonReconciled(ctx, garbageCollectionController) + ExpectSingletonReconciled(ctx, virtualMachineGCController) wg := sync.WaitGroup{} for _, id := range ids { @@ -259,7 +251,7 @@ var _ = Describe("GarbageCollection", func() { } azureEnv.VirtualMachinesAPI.Instances.Store(lo.FromPtr(vm.ID), *vm) - ExpectSingletonReconciled(ctx, garbageCollectionController) + ExpectSingletonReconciled(ctx, virtualMachineGCController) _, err := cloudProvider.Get(ctx, providerID) Expect(err).NotTo(HaveOccurred()) }) @@ -280,7 +272,7 @@ var _ = Describe("GarbageCollection", func() { }) ExpectApplied(ctx, env.Client, nodeClaim, node) - ExpectSingletonReconciled(ctx, garbageCollectionController) + ExpectSingletonReconciled(ctx, virtualMachineGCController) _, err := cloudProvider.Get(ctx, providerID) Expect(err).ToNot(HaveOccurred()) ExpectExists(ctx, env.Client, node) @@ -289,16 +281,8 @@ var _ = Describe("GarbageCollection", func() { var _ = Context("Basic", func() { BeforeEach(func() { - id := utils.MkVMID(azureEnv.AzureResourceGraphAPI.ResourceGroup, "vm-a") - vm = &armcompute.VirtualMachine{ - ID: lo.ToPtr(id), - Name: lo.ToPtr("vm-a"), - Location: lo.ToPtr(fake.Region), - Tags: map[string]*string{ - instance.NodePoolTagKey: lo.ToPtr("default"), - }, - } - providerID = utils.ResourceIDToProviderID(ctx, id) + vm = test.VirtualMachine(test.VirtualMachineOptions{Name: "vm-a", NodepoolName: "default"}) + providerID = utils.ResourceIDToProviderID(ctx, lo.FromPtr(vm.ID)) }) It("should delete an instance if there is no NodeClaim owner", func() { // Launch happened 10m ago @@ -307,7 +291,7 @@ var _ = Describe("GarbageCollection", func() { } azureEnv.VirtualMachinesAPI.Instances.Store(lo.FromPtr(vm.ID), *vm) - ExpectSingletonReconciled(ctx, garbageCollectionController) + ExpectSingletonReconciled(ctx, virtualMachineGCController) _, err = cloudProvider.Get(ctx, providerID) Expect(err).To(HaveOccurred()) Expect(corecloudprovider.IsNodeClaimNotFoundError(err)).To(BeTrue()) @@ -323,7 +307,7 @@ var _ = Describe("GarbageCollection", func() { }) ExpectApplied(ctx, env.Client, node) - ExpectSingletonReconciled(ctx, garbageCollectionController) + ExpectSingletonReconciled(ctx, virtualMachineGCController) _, err = cloudProvider.Get(ctx, providerID) Expect(err).To(HaveOccurred()) Expect(corecloudprovider.IsNodeClaimNotFoundError(err)).To(BeTrue()) @@ -332,3 +316,77 @@ var _ = Describe("GarbageCollection", func() { }) }) }) + +var _ = Describe("NetworkInterface Garbage Collection", func() { + It("should not delete a network interface if a nodeclaim exists for it", func() { + // Create and apply a NodeClaim that references this NIC + nodeClaim := coretest.NodeClaim() + ExpectApplied(ctx, env.Client, nodeClaim) + + // Create a managed NIC + nic := test.Interface(test.InterfaceOptions{Name: instance.GenerateResourceName(nodeClaim.Name), NodepoolName: nodePool.Name}) + azureEnv.NetworkInterfacesAPI.NetworkInterfaces.Store(lo.FromPtr(nic.ID), *nic) + + nicsBeforeGC, err := azureEnv.InstanceProvider.ListNics(ctx) + ExpectNoError(err) + Expect(len(nicsBeforeGC)).To(Equal(1)) + + // Run garbage collection + ExpectSingletonReconciled(ctx, networkInterfaceGCController) + + // Verify NIC still exists after GC + nicsAfterGC, err := azureEnv.InstanceProvider.ListNics(ctx) + ExpectNoError(err) + Expect(len(nicsAfterGC)).To(Equal(1)) + }) + It("should delete a NIC if there is no associated VM", func() { + nic := test.Interface(test.InterfaceOptions{NodepoolName: nodePool.Name}) + nic2 := test.Interface(test.InterfaceOptions{NodepoolName: nodePool.Name}) + azureEnv.NetworkInterfacesAPI.NetworkInterfaces.Store(lo.FromPtr(nic.ID), *nic) + azureEnv.NetworkInterfacesAPI.NetworkInterfaces.Store(lo.FromPtr(nic2.ID), *nic2) + nicsBeforeGC, err := azureEnv.InstanceProvider.ListNics(ctx) + ExpectNoError(err) + Expect(len(nicsBeforeGC)).To(Equal(2)) + // add a nic to azure env, and call reconcile. It should show up in the list before reconcile + // then it should not showup after + ExpectSingletonReconciled(ctx, networkInterfaceGCController) + nicsAfterGC, err := azureEnv.InstanceProvider.ListNics(ctx) + ExpectNoError(err) + Expect(len(nicsAfterGC)).To(Equal(0)) + }) + It("should not delete a NIC if there is an associated VM", func() { + managedNic := test.Interface(test.InterfaceOptions{NodepoolName: nodePool.Name}) + azureEnv.NetworkInterfacesAPI.NetworkInterfaces.Store(lo.FromPtr(managedNic.ID), *managedNic) + managedVM := test.VirtualMachine(test.VirtualMachineOptions{Name: lo.FromPtr(managedNic.Name), NodepoolName: nodePool.Name}) + azureEnv.VirtualMachinesAPI.VirtualMachinesBehavior.Instances.Store(lo.FromPtr(managedVM.ID), *managedVM) + ExpectSingletonReconciled(ctx, networkInterfaceGCController) + // We should still have a network interface here + nicsAfterGC, err := azureEnv.InstanceProvider.ListNics(ctx) + ExpectNoError(err) + Expect(len(nicsAfterGC)).To(Equal(1)) + + }) + It("the vm gc controller should remove the nic if there is an associated vm", func() { + managedNic := test.Interface(test.InterfaceOptions{NodepoolName: nodePool.Name}) + azureEnv.NetworkInterfacesAPI.NetworkInterfaces.Store(lo.FromPtr(managedNic.ID), *managedNic) + managedVM := test.VirtualMachine(test.VirtualMachineOptions{ + Name: lo.FromPtr(managedNic.Name), + NodepoolName: nodePool.Name, + Properties: &armcompute.VirtualMachineProperties{ + TimeCreated: lo.ToPtr(time.Now().Add(-time.Minute * 16)), // Needs to be older than the nodeclaim registration ttl + }, + }) + azureEnv.VirtualMachinesAPI.VirtualMachinesBehavior.Instances.Store(lo.FromPtr(managedVM.ID), *managedVM) + ExpectSingletonReconciled(ctx, networkInterfaceGCController) + // We should still have a network interface here + nicsAfterGC, err := azureEnv.InstanceProvider.ListNics(ctx) + ExpectNoError(err) + Expect(len(nicsAfterGC)).To(Equal(1)) + + ExpectSingletonReconciled(ctx, virtualMachineGCController) + nicsAfterVMReconciliation, err := azureEnv.InstanceProvider.ListNics(ctx) + ExpectNoError(err) + Expect(len(nicsAfterVMReconciliation)).To(Equal(0)) + + }) +}) diff --git a/pkg/fake/azureresourcegraphapi.go b/pkg/fake/azureresourcegraphapi.go index fe160ae62..9eea36287 100644 --- a/pkg/fake/azureresourcegraphapi.go +++ b/pkg/fake/azureresourcegraphapi.go @@ -23,6 +23,7 @@ import ( "github.com/samber/lo" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resourcegraph/armresourcegraph" "github.com/Azure/karpenter-provider-azure/pkg/providers/instance" ) @@ -35,6 +36,7 @@ type AzureResourceGraphResourcesInput struct { type AzureResourceGraphBehavior struct { AzureResourceGraphResourcesBehavior MockedFunction[AzureResourceGraphResourcesInput, armresourcegraph.ClientResourcesResponse] VirtualMachinesAPI *VirtualMachinesAPI + NetworkInterfacesAPI *NetworkInterfacesAPI ResourceGroup string } @@ -42,9 +44,23 @@ type AzureResourceGraphBehavior struct { var _ instance.AzureResourceGraphAPI = &AzureResourceGraphAPI{} type AzureResourceGraphAPI struct { + vmListQuery string + nicListQuery string AzureResourceGraphBehavior } +func NewAzureResourceGraphAPI(resourceGroup string, virtualMachinesAPI *VirtualMachinesAPI, networkInterfacesAPI *NetworkInterfacesAPI) *AzureResourceGraphAPI { + return &AzureResourceGraphAPI{ + vmListQuery: instance.GetVMListQueryBuilder(resourceGroup).String(), + nicListQuery: instance.GetNICListQueryBuilder(resourceGroup).String(), + AzureResourceGraphBehavior: AzureResourceGraphBehavior{ + VirtualMachinesAPI: virtualMachinesAPI, + NetworkInterfacesAPI: networkInterfacesAPI, + ResourceGroup: resourceGroup, + }, + } +} + // Reset must be called between tests otherwise tests will pollute each other. func (c *AzureResourceGraphAPI) Reset() {} @@ -66,7 +82,7 @@ func (c *AzureResourceGraphAPI) Resources(_ context.Context, query armresourcegr func (c *AzureResourceGraphAPI) getResourceList(query string) []interface{} { switch query { - case instance.GetListQueryBuilder(c.ResourceGroup).String(): + case c.vmListQuery: vmList := lo.Filter(c.loadVMObjects(), func(vm armcompute.VirtualMachine, _ int) bool { return vm.Tags != nil && vm.Tags[instance.NodePoolTagKey] != nil }) @@ -75,12 +91,20 @@ func (c *AzureResourceGraphAPI) getResourceList(query string) []interface{} { return convertBytesToInterface(b) }) return resourceList + case c.nicListQuery: + nicList := lo.Filter(c.loadNicObjects(), func(nic armnetwork.Interface, _ int) bool { + return nic.Tags != nil && nic.Tags[instance.NodePoolTagKey] != nil + }) + resourceList := lo.Map(nicList, func(nic armnetwork.Interface, _ int) interface{} { + b, _ := json.Marshal(nic) + return convertBytesToInterface(b) + }) + return resourceList } return nil } -func (c *AzureResourceGraphAPI) loadVMObjects() []armcompute.VirtualMachine { - vmList := []armcompute.VirtualMachine{} +func (c *AzureResourceGraphAPI) loadVMObjects() (vmList []armcompute.VirtualMachine) { c.VirtualMachinesAPI.Instances.Range(func(k, v any) bool { vm, _ := c.VirtualMachinesAPI.Instances.Load(k) vmList = append(vmList, vm.(armcompute.VirtualMachine)) @@ -89,6 +113,15 @@ func (c *AzureResourceGraphAPI) loadVMObjects() []armcompute.VirtualMachine { return vmList } +func (c *AzureResourceGraphAPI) loadNicObjects() (nicList []armnetwork.Interface) { + c.NetworkInterfacesAPI.NetworkInterfaces.Range(func(k, v any) bool { + nic, _ := c.NetworkInterfacesAPI.NetworkInterfaces.Load(k) + nicList = append(nicList, nic.(armnetwork.Interface)) + return true + }) + return nicList +} + func convertBytesToInterface(b []byte) interface{} { jsonObj := instance.Resource{} _ = json.Unmarshal(b, &jsonObj) diff --git a/pkg/fake/azureresourcegraphapi_test.go b/pkg/fake/azureresourcegraphapi_test.go index f23bc2796..f5bc7b1ef 100644 --- a/pkg/fake/azureresourcegraphapi_test.go +++ b/pkg/fake/azureresourcegraphapi_test.go @@ -32,7 +32,7 @@ func TestAzureResourceGraphAPI_Resources_VM(t *testing.T) { resourceGroup := "test_managed_cluster_rg" subscriptionID := "test_sub" virtualMachinesAPI := &VirtualMachinesAPI{} - azureResourceGraphAPI := &AzureResourceGraphAPI{AzureResourceGraphBehavior{VirtualMachinesAPI: virtualMachinesAPI, ResourceGroup: resourceGroup}} + azureResourceGraphAPI := NewAzureResourceGraphAPI(resourceGroup, virtualMachinesAPI, nil) cases := []struct { testName string vmNames []string @@ -67,7 +67,7 @@ func TestAzureResourceGraphAPI_Resources_VM(t *testing.T) { return } } - queryRequest := instance.NewQueryRequest(&subscriptionID, instance.GetListQueryBuilder(resourceGroup).String()) + queryRequest := instance.NewQueryRequest(&subscriptionID, instance.GetVMListQueryBuilder(resourceGroup).String()) data, err := instance.GetResourceData(context.Background(), azureResourceGraphAPI, *queryRequest) if err != nil { t.Errorf("Unexpected error %v", err) diff --git a/pkg/fake/networkinterfaceapi.go b/pkg/fake/networkinterfaceapi.go index f1fed7163..96404fdff 100644 --- a/pkg/fake/networkinterfaceapi.go +++ b/pkg/fake/networkinterfaceapi.go @@ -73,6 +73,7 @@ func (c *NetworkInterfacesAPI) BeginCreateOrUpdate(_ context.Context, resourceGr return c.NetworkInterfacesCreateOrUpdateBehavior.Invoke(input, func(input *NetworkInterfaceCreateOrUpdateInput) (*armnetwork.InterfacesClientCreateOrUpdateResponse, error) { iface := input.Interface + iface.Name = to.StringPtr(input.InterfaceName) id := mkNetworkInterfaceID(input.ResourceGroupName, input.InterfaceName) iface.ID = to.StringPtr(id) c.NetworkInterfaces.Store(id, iface) @@ -99,7 +100,7 @@ func (c *NetworkInterfacesAPI) BeginDelete(_ context.Context, resourceGroupName InterfaceName: interfaceName, } return c.NetworkInterfacesDeleteBehavior.Invoke(input, func(input *NetworkInterfaceDeleteInput) (*armnetwork.InterfacesClientDeleteResponse, error) { - id := mkNetworkInterfaceID(resourceGroupName, interfaceName) + id := mkNetworkInterfaceID(input.ResourceGroupName, input.InterfaceName) c.NetworkInterfaces.Delete(id) return &armnetwork.InterfacesClientDeleteResponse{}, nil }) diff --git a/pkg/providers/instance/azureresourcegraphlist.go b/pkg/providers/instance/azureresourcegraphlist.go new file mode 100644 index 000000000..fc41fd5f0 --- /dev/null +++ b/pkg/providers/instance/azureresourcegraphlist.go @@ -0,0 +1,108 @@ +/* +Portions Copyright (c) Microsoft Corporation. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package instance + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/Azure/azure-kusto-go/kusto/kql" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" + "github.com/samber/lo" +) + +const ( + vmResourceType = "microsoft.compute/virtualmachines" + nicResourceType = "microsoft.network/networkinterfaces" +) + +// getResourceListQueryBuilder returns a KQL query builder for listing resources with nodepool tags +func getResourceListQueryBuilder(rg string, resourceType string) *kql.Builder { + return kql.New(`Resources`). + AddLiteral(` | where type == `).AddString(resourceType). + AddLiteral(` | where resourceGroup == `).AddString(strings.ToLower(rg)). // ARG resources appear to have lowercase RG + AddLiteral(` | where tags has_cs `).AddString(NodePoolTagKey) +} + +// GetVMListQueryBuilder returns a KQL query builder for listing VMs with nodepool tags +func GetVMListQueryBuilder(rg string) *kql.Builder { + return getResourceListQueryBuilder(rg, vmResourceType) +} + +// GetNICListQueryBuilder returns a KQL query builder for listing NICs with nodepool tags +func GetNICListQueryBuilder(rg string) *kql.Builder { + return getResourceListQueryBuilder(rg, nicResourceType) +} + +// createVMFromQueryResponseData converts ARG query response data into a VirtualMachine object +func createVMFromQueryResponseData(data map[string]interface{}) (*armcompute.VirtualMachine, error) { + jsonString, err := json.Marshal(data) + if err != nil { + return nil, err + } + vm := armcompute.VirtualMachine{} + err = json.Unmarshal(jsonString, &vm) + if err != nil { + return nil, err + } + if vm.ID == nil { + return nil, fmt.Errorf("virtual machine is missing id") + } + if vm.Name == nil { + return nil, fmt.Errorf("virtual machine is missing name") + } + if vm.Tags == nil { + return nil, fmt.Errorf("virtual machine is missing tags") + } + // We see inconsistent casing being returned by ARG for the last segment + // of the vm.ID string. This forces it to be lowercase. + parts := strings.Split(lo.FromPtr(vm.ID), "/") + parts[len(parts)-1] = strings.ToLower(parts[len(parts)-1]) + vm.ID = lo.ToPtr(strings.Join(parts, "/")) + return &vm, nil +} + +// createNICFromQueryResponseData converts ARG query response data into a Network Interface object +func createNICFromQueryResponseData(data map[string]interface{}) (*armnetwork.Interface, error) { + jsonString, err := json.Marshal(data) + if err != nil { + return nil, err + } + + nic := armnetwork.Interface{} + err = json.Unmarshal(jsonString, &nic) + if err != nil { + return nil, err + } + if nic.ID == nil { + return nil, fmt.Errorf("network interface is missing id") + } + if nic.Name == nil { + return nil, fmt.Errorf("network interface is missing name") + } + if nic.Tags == nil { + return nil, fmt.Errorf("network interface is missing tags") + } + // We see inconsistent casing being returned by ARG for the last segment + // of the nic.ID string. This forces it to be lowercase. + parts := strings.Split(lo.FromPtr(nic.ID), "/") + parts[len(parts)-1] = strings.ToLower(parts[len(parts)-1]) + nic.ID = lo.ToPtr(strings.Join(parts, "/")) + return &nic, nil +} diff --git a/pkg/providers/instance/argutils.go b/pkg/providers/instance/azureresourcegraphutils.go similarity index 100% rename from pkg/providers/instance/argutils.go rename to pkg/providers/instance/azureresourcegraphutils.go diff --git a/pkg/providers/instance/armutils.go b/pkg/providers/instance/azureresourcemanagerutils.go similarity index 100% rename from pkg/providers/instance/armutils.go rename to pkg/providers/instance/azureresourcemanagerutils.go diff --git a/pkg/providers/instance/instance.go b/pkg/providers/instance/instance.go index 975eb2a92..b8c2fb77f 100644 --- a/pkg/providers/instance/instance.go +++ b/pkg/providers/instance/instance.go @@ -18,7 +18,6 @@ package instance import ( "context" - "encoding/json" "errors" "fmt" "math" @@ -32,7 +31,6 @@ import ( "k8s.io/apimachinery/pkg/util/sets" "knative.dev/pkg/logging" - "github.com/Azure/azure-kusto-go/kusto/kql" "github.com/Azure/karpenter-provider-azure/pkg/cache" "github.com/Azure/karpenter-provider-azure/pkg/providers/instancetype" "github.com/Azure/karpenter-provider-azure/pkg/providers/launchtemplate" @@ -55,9 +53,13 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" ) +var ( + vmListQuery string + nicListQuery string +) + var ( NodePoolTagKey = strings.ReplaceAll(karpv1.NodePoolLabelKey, "/", "_") - listQuery string CapacityTypeToPriority = map[string]string{ karpv1.CapacityTypeSpot: string(compute.Spot), @@ -87,6 +89,8 @@ type Provider interface { // CreateTags(context.Context, string, map[string]string) error Update(context.Context, string, armcompute.VirtualMachineUpdate) error GetNic(context.Context, string, string) (*armnetwork.Interface, error) + DeleteNic(context.Context, string) error + ListNics(context.Context) ([]*armnetwork.Interface, error) } // assert that DefaultProvider implements Provider interface @@ -115,7 +119,8 @@ func NewDefaultProvider( subscriptionID string, provisionMode string, ) *DefaultProvider { - listQuery = GetListQueryBuilder(resourceGroup).String() + vmListQuery = GetVMListQueryBuilder(resourceGroup).String() + nicListQuery = GetNICListQueryBuilder(resourceGroup).String() return &DefaultProvider{ azClient: azClient, instanceTypeProvider: instanceTypeProvider, @@ -175,7 +180,7 @@ func (p *DefaultProvider) Get(ctx context.Context, vmName string) (*armcompute.V } func (p *DefaultProvider) List(ctx context.Context) ([]*armcompute.VirtualMachine, error) { - req := NewQueryRequest(&(p.subscriptionID), listQuery) + req := NewQueryRequest(&(p.subscriptionID), vmListQuery) client := p.azClient.azureResourceGraphClient data, err := GetResourceData(ctx, client, *req) if err != nil { @@ -197,6 +202,37 @@ func (p *DefaultProvider) Delete(ctx context.Context, resourceName string) error return p.cleanupAzureResources(ctx, resourceName) } +func (p *DefaultProvider) GetNic(ctx context.Context, rg, nicName string) (*armnetwork.Interface, error) { + nicResponse, err := p.azClient.networkInterfacesClient.Get(ctx, rg, nicName, nil) + if err != nil { + return nil, err + } + return &nicResponse.Interface, nil +} + +// ListNics returns all network interfaces in the resource group that have the nodepool tag +func (p *DefaultProvider) ListNics(ctx context.Context) ([]*armnetwork.Interface, error) { + req := NewQueryRequest(&(p.subscriptionID), nicListQuery) + client := p.azClient.azureResourceGraphClient + data, err := GetResourceData(ctx, client, *req) + if err != nil { + return nil, fmt.Errorf("querying azure resource graph, %w", err) + } + var nicList []*armnetwork.Interface + for i := range data { + nic, err := createNICFromQueryResponseData(data[i]) + if err != nil { + return nil, fmt.Errorf("creating NIC object from query response data, %w", err) + } + nicList = append(nicList, nic) + } + return nicList, nil +} + +func (p *DefaultProvider) DeleteNic(ctx context.Context, nicName string) error { + return deleteNicIfExists(ctx, p.azClient.networkInterfacesClient, p.resourceGroup, nicName) +} + // createAKSIdentifyingExtension attaches a VM extension to identify that this VM participates in an AKS cluster func (p *DefaultProvider) createAKSIdentifyingExtension(ctx context.Context, vmName string) (err error) { vmExt := p.getAKSIdentifyingExtension() @@ -302,14 +338,6 @@ func (p *DefaultProvider) createNetworkInterface(ctx context.Context, opts *crea return *res.ID, nil } -func (p *DefaultProvider) GetNic(ctx context.Context, rg, nicName string) (*armnetwork.Interface, error) { - nicResponse, err := p.azClient.networkInterfacesClient.Get(ctx, rg, nicName, nil) - if err != nil { - return nil, err - } - return &nicResponse.Interface, nil -} - // newVMObject is a helper func that creates a new armcompute.VirtualMachine // from key input. func newVMObject( @@ -642,11 +670,11 @@ func (p *DefaultProvider) cleanupAzureResources(ctx context.Context, resourceNam // The order here is intentional, if the VM was created successfully, then we attempt to delete the vm, the // nic, disk and all associated resources will be removed. If the VM was not created successfully and a nic was found, // then we attempt to delete the nic. + nicErr := deleteNicIfExists(ctx, p.azClient.networkInterfacesClient, p.resourceGroup, resourceName) if nicErr != nil { - logging.FromContext(ctx).Errorf("networkInterface.Delete for %s failed: %v", resourceName, nicErr) + logging.FromContext(ctx).Errorf("networkinterface.Delete for %s failed: %v", resourceName, nicErr) } - return errors.Join(vmErr, nicErr) } @@ -748,40 +776,6 @@ func (p *DefaultProvider) getCSExtension(cse string, isWindows bool) *armcompute } } -func GetListQueryBuilder(rg string) *kql.Builder { - return kql.New(`Resources`). - AddLiteral(` | where type == "microsoft.compute/virtualmachines"`). - AddLiteral(` | where resourceGroup == `).AddString(strings.ToLower(rg)). // ARG VMs appear to have lowercase RG - AddLiteral(` | where tags has_cs `).AddString(NodePoolTagKey) -} - -func createVMFromQueryResponseData(data map[string]interface{}) (*armcompute.VirtualMachine, error) { - jsonString, err := json.Marshal(data) - if err != nil { - return nil, err - } - vm := armcompute.VirtualMachine{} - err = json.Unmarshal(jsonString, &vm) - if err != nil { - return nil, err - } - if vm.ID == nil { - return nil, fmt.Errorf("virtual machine is missing id") - } - if vm.Name == nil { - return nil, fmt.Errorf("virtual machine is missing name") - } - if vm.Tags == nil { - return nil, fmt.Errorf("virtual machine is missing tags") - } - // We see inconsistent casing being returned by ARG for the last segment - // of the vm.ID string. This forces it to be lowercase. - parts := strings.Split(lo.FromPtr(vm.ID), "/") - parts[len(parts)-1] = strings.ToLower(parts[len(parts)-1]) - vm.ID = lo.ToPtr(strings.Join(parts, "/")) - return &vm, nil -} - func ConvertToVirtualMachineIdentity(nodeIdentities []string) *armcompute.VirtualMachineIdentity { var identity *armcompute.VirtualMachineIdentity if len(nodeIdentities) > 0 { diff --git a/pkg/providers/instance/instance_test.go b/pkg/providers/instance/instance_test.go index 06710aa79..0c08e6578 100644 --- a/pkg/providers/instance/instance_test.go +++ b/pkg/providers/instance/instance_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" "github.com/Azure/karpenter-provider-azure/pkg/cache" "github.com/stretchr/testify/assert" corev1 "k8s.io/api/core/v1" @@ -100,6 +101,67 @@ func TestGetPriorityCapacityAndInstanceType(t *testing.T) { } } +func TestCreateNICFromQueryResponseData(t *testing.T) { + id := "nic_id" + name := "nic_name" + tag := "tag1" + val := "val1" + tags := map[string]*string{tag: &val} + + tc := []struct { + testName string + data map[string]interface{} + expectedError string + expectedNIC *armnetwork.Interface + }{ + { + testName: "missing id", + data: map[string]interface{}{ + "name": name, + }, + expectedError: "network interface is missing id", + expectedNIC: nil, + }, + { + testName: "missing name", + data: map[string]interface{}{ + "id": id, + }, + expectedError: "network interface is missing name", + expectedNIC: nil, + }, + { + testName: "happy case", + data: map[string]interface{}{ + "id": id, + "name": name, + "tags": map[string]interface{}{tag: val}, + }, + expectedNIC: &armnetwork.Interface{ + ID: &id, + Name: &name, + Tags: tags, + }, + }, + } + + for _, c := range tc { + nic, err := createNICFromQueryResponseData(c.data) + if nic != nil { + expected := *c.expectedNIC + actual := *nic + assert.Equal(t, *expected.ID, *actual.ID, c.testName) + assert.Equal(t, *expected.Name, *actual.Name, c.testName) + for key := range expected.Tags { + assert.Equal(t, *(expected.Tags[key]), *(actual.Tags[key]), c.testName) + } + } + if err != nil { + assert.Equal(t, c.expectedError, err.Error(), c.testName) + } + } +} + // Currently tested: ID, Name, Tags, Zones // TODO: Add the below attributes for Properties if needed: // Priority, InstanceView.HyperVGeneration, TimeCreated diff --git a/pkg/providers/instance/suite_test.go b/pkg/providers/instance/suite_test.go index 923f24e17..6ef0dc1ec 100644 --- a/pkg/providers/instance/suite_test.go +++ b/pkg/providers/instance/suite_test.go @@ -30,8 +30,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" clock "k8s.io/utils/clock/testing" - "k8s.io/client-go/tools/record" - "github.com/Azure/karpenter-provider-azure/pkg/apis" "github.com/Azure/karpenter-provider-azure/pkg/apis/v1alpha2" "github.com/Azure/karpenter-provider-azure/pkg/cloudprovider" @@ -39,6 +37,8 @@ import ( "github.com/Azure/karpenter-provider-azure/pkg/operator/options" "github.com/Azure/karpenter-provider-azure/pkg/providers/instance" "github.com/Azure/karpenter-provider-azure/pkg/test" + "k8s.io/client-go/tools/record" + "sigs.k8s.io/karpenter/pkg/controllers/provisioning" "sigs.k8s.io/karpenter/pkg/controllers/state" "sigs.k8s.io/karpenter/pkg/events" @@ -46,6 +46,7 @@ import ( karpv1 "sigs.k8s.io/karpenter/pkg/apis/v1" corecloudprovider "sigs.k8s.io/karpenter/pkg/cloudprovider" + . "github.com/Azure/karpenter-provider-azure/pkg/test/expectations" . "knative.dev/pkg/logging/testing" . "sigs.k8s.io/karpenter/pkg/test/expectations" "sigs.k8s.io/karpenter/pkg/test/v1alpha1" @@ -215,4 +216,24 @@ var _ = Describe("InstanceProvider", func() { return strings.Contains(key, "/") // ARM tags can't contain '/' })).To(HaveLen(0)) }) + It("should list nic from karpenter provisioning request", func() { + ExpectApplied(ctx, env.Client, nodePool, nodeClass) + pod := coretest.UnschedulablePod(coretest.PodOptions{}) + ExpectProvisioned(ctx, env.Client, cluster, cloudProvider, coreProvisioner, pod) + ExpectScheduled(ctx, env.Client, pod) + interfaces, err := azureEnv.InstanceProvider.ListNics(ctx) + Expect(err).To(BeNil()) + Expect(len(interfaces)).To(Equal(1)) + }) + It("should only list nics that belong to karpenter", func() { + managedNic := test.Interface(test.InterfaceOptions{NodepoolName: nodePool.Name}) + unmanagedNic := test.Interface(test.InterfaceOptions{Tags: map[string]*string{"kubernetes.io/cluster/test-cluster": lo.ToPtr("random-aks-vm")}}) + + azureEnv.NetworkInterfacesAPI.NetworkInterfaces.Store(lo.FromPtr(managedNic.ID), *managedNic) + azureEnv.NetworkInterfacesAPI.NetworkInterfaces.Store(lo.FromPtr(unmanagedNic.ID), *unmanagedNic) + interfaces, err := azureEnv.InstanceProvider.ListNics(ctx) + ExpectNoError(err) + Expect(len(interfaces)).To(Equal(1)) + Expect(interfaces[0].Name).To(Equal(managedNic.Name)) + }) }) diff --git a/pkg/test/environment.go b/pkg/test/environment.go index 11d3faf0e..bcf8ab48f 100644 --- a/pkg/test/environment.go +++ b/pkg/test/environment.go @@ -92,15 +92,16 @@ func NewRegionalEnvironment(ctx context.Context, env *coretest.Environment, regi // API virtualMachinesAPI := &fake.VirtualMachinesAPI{} - azureResourceGraphAPI := &fake.AzureResourceGraphAPI{AzureResourceGraphBehavior: fake.AzureResourceGraphBehavior{VirtualMachinesAPI: virtualMachinesAPI, ResourceGroup: resourceGroup}} - virtualMachinesExtensionsAPI := &fake.VirtualMachineExtensionsAPI{} + networkInterfacesAPI := &fake.NetworkInterfacesAPI{} + virtualMachinesExtensionsAPI := &fake.VirtualMachineExtensionsAPI{} pricingAPI := &fake.PricingAPI{} skuClientSingleton := &fake.MockSkuClientSingleton{SKUClient: &fake.ResourceSKUsAPI{Location: region}} communityImageVersionsAPI := &fake.CommunityGalleryImageVersionsAPI{} loadBalancersAPI := &fake.LoadBalancersAPI{} nodeImageVersionsAPI := &fake.NodeImageVersionsAPI{} + azureResourceGraphAPI := fake.NewAzureResourceGraphAPI(resourceGroup, virtualMachinesAPI, networkInterfacesAPI) // Cache kubernetesVersionCache := cache.New(azurecache.KubernetesVersionTTL, azurecache.DefaultCleanupInterval) instanceTypeCache := cache.New(instancetype.InstanceTypesCacheTTL, azurecache.DefaultCleanupInterval) diff --git a/pkg/test/expectations/expectations.go b/pkg/test/expectations/expectations.go index 5184d2d27..d6f1e5632 100644 --- a/pkg/test/expectations/expectations.go +++ b/pkg/test/expectations/expectations.go @@ -21,10 +21,9 @@ import ( "fmt" "strings" + "github.com/Azure/karpenter-provider-azure/pkg/test" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - - "github.com/Azure/karpenter-provider-azure/pkg/test" ) func ExpectUnavailable(env *test.Environment, instanceType string, zone string, capacityType string) { @@ -54,3 +53,8 @@ func ExpectDecodedCustomData(env *test.Environment) string { return decodedString } + +func ExpectNoError(err error) { + GinkgoHelper() + Expect(err).To(BeNil()) +} diff --git a/pkg/test/networkinterfaces.go b/pkg/test/networkinterfaces.go new file mode 100644 index 000000000..c879a7495 --- /dev/null +++ b/pkg/test/networkinterfaces.go @@ -0,0 +1,83 @@ +/* +Portions Copyright (c) Microsoft Corporation. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package test + +import ( + "fmt" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" + "github.com/Azure/karpenter-provider-azure/pkg/fake" + "github.com/imdario/mergo" + "github.com/samber/lo" +) + +// InterfaceOptions customizes an Azure Network Interface for testing. +type InterfaceOptions struct { + Name string + NodepoolName string + Location string + Properties *armnetwork.InterfacePropertiesFormat + Tags map[string]*string +} + +// Interface creates a test Azure Network Interface with defaults that can be overridden by InterfaceOptions. +// Overrides are applied in order, with last-write-wins semantics. +func Interface(overrides ...InterfaceOptions) *armnetwork.Interface { + options := InterfaceOptions{} + for _, o := range overrides { + if err := mergo.Merge(&options, o, mergo.WithOverride); err != nil { + panic(fmt.Sprintf("Failed to merge Interface options: %s", err)) + } + } + + // Provide default values if none are set + if options.Name == "" { + options.Name = RandomName("aks") + } + if options.NodepoolName == "" { + options.NodepoolName = "default" + } + if options.Location == "" { + options.Location = fake.Region + } + if options.Tags == nil { + options.Tags = ManagedTags(options.NodepoolName) + } + if options.Properties == nil { + options.Properties = &armnetwork.InterfacePropertiesFormat{ + IPConfigurations: []*armnetwork.InterfaceIPConfiguration{ + { + Name: lo.ToPtr("ipConfig"), + Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{ + PrivateIPAllocationMethod: lo.ToPtr(armnetwork.IPAllocationMethodDynamic), + Subnet: &armnetwork.Subnet{ID: lo.ToPtr("/subscriptions/.../resourceGroups/.../providers/Microsoft.Network/virtualNetworks/.../subnets/default")}, + }, + }, + }, + } + } + + nic := &armnetwork.Interface{ + ID: lo.ToPtr(fmt.Sprintf("/subscriptions/subscriptionID/resourceGroups/test-resourceGroup/providers/Microsoft.Network/networkInterfaces/%s", options.Name)), + Name: &options.Name, + Location: &options.Location, + Properties: options.Properties, + Tags: options.Tags, + } + + return nic +} diff --git a/pkg/test/utils.go b/pkg/test/utils.go new file mode 100644 index 000000000..1c1af00f9 --- /dev/null +++ b/pkg/test/utils.go @@ -0,0 +1,35 @@ +/* +Portions Copyright (c) Microsoft Corporation. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package test + +import ( + "github.com/samber/lo" + k8srand "k8s.io/apimachinery/pkg/util/rand" +) + +// RandomName returns a pseudo-random resource name with a given prefix. +func RandomName(prefix string) string { + // You could make this more robust by including additional random characters. + return prefix + "-" + k8srand.String(10) +} + +func ManagedTags(nodepoolName string) map[string]*string { + return map[string]*string{ + "karpenter.sh_cluster": lo.ToPtr("test-cluster"), + "karpenter.sh_nodepool": lo.ToPtr(nodepoolName), + } +} diff --git a/pkg/test/virtualmachines.go b/pkg/test/virtualmachines.go new file mode 100644 index 000000000..7ad01321a --- /dev/null +++ b/pkg/test/virtualmachines.go @@ -0,0 +1,78 @@ +/* +Portions Copyright (c) Microsoft Corporation. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package test + +import ( + "fmt" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute" + "github.com/Azure/karpenter-provider-azure/pkg/fake" + "github.com/imdario/mergo" + "github.com/samber/lo" +) + +// VirtualMachineOptions customizes an Azure Virtual Machine for testing. +type VirtualMachineOptions struct { + Name string + NodepoolName string + Location string + Properties *armcompute.VirtualMachineProperties + Tags map[string]*string +} + +// VirtualMachine creates a test Azure Virtual Machine with defaults that can be overridden by VirtualMachineOptions. +// Overrides are applied in order, with last-write-wins semantics. +func VirtualMachine(overrides ...VirtualMachineOptions) *armcompute.VirtualMachine { + options := VirtualMachineOptions{} + for _, o := range overrides { + if err := mergo.Merge(&options, o, mergo.WithOverride); err != nil { + panic(fmt.Sprintf("Failed to merge VirtualMachine options: %s", err)) + } + } + + // Provide default values if none are set + if options.Name == "" { + options.Name = RandomName("aks") + } + if options.NodepoolName == "" { + options.NodepoolName = "default" + } + if options.Location == "" { + options.Location = fake.Region + } + if options.Properties == nil { + options.Properties = &armcompute.VirtualMachineProperties{} + } + if options.Tags == nil { + options.Tags = ManagedTags(options.NodepoolName) + } + if options.Properties.TimeCreated == nil { + options.Properties.TimeCreated = lo.ToPtr(time.Now()) + } + + // Construct the basic VM + vm := &armcompute.VirtualMachine{ + ID: lo.ToPtr(fmt.Sprintf("/subscriptions/subscriptionID/resourceGroups/test-resourceGroup/providers/Microsoft.Compute/virtualMachines/%s", options.Name)), + Name: lo.ToPtr(options.Name), + Location: lo.ToPtr(options.Location), + Properties: options.Properties, + Tags: options.Tags, + } + + return vm +}