diff --git a/pkg/frontend/quota_test.go b/pkg/frontend/quota_test.go index b0d4c6ba577..226c34642e5 100644 --- a/pkg/frontend/quota_test.go +++ b/pkg/frontend/quota_test.go @@ -7,14 +7,14 @@ import ( "context" "testing" + sdknetwork "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v2" mgmtcompute "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2020-06-01/compute" - mgmtnetwork "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2020-08-01/network" "github.com/Azure/go-autorest/autorest/to" "go.uber.org/mock/gomock" "github.com/Azure/ARO-RP/pkg/api" + mock_armnetwork "github.com/Azure/ARO-RP/pkg/util/mocks/azureclient/azuresdk/armnetwork" mock_compute "github.com/Azure/ARO-RP/pkg/util/mocks/azureclient/mgmt/compute" - mock_network "github.com/Azure/ARO-RP/pkg/util/mocks/azureclient/mgmt/network" utilerror "github.com/Azure/ARO-RP/test/util/error" ) @@ -23,13 +23,13 @@ func TestValidateQuota(t *testing.T) { type test struct { name string - mocks func(*test, *mock_compute.MockUsageClient, *mock_network.MockUsageClient) + mocks func(*test, *mock_compute.MockUsageClient, *mock_armnetwork.MockUsagesClient) wantErr string } for _, tt := range []*test{ { name: "allow when there's enough resources - limits set to exact requirements, offset by 100 of current value", - mocks: func(tt *test, cuc *mock_compute.MockUsageClient, nuc *mock_network.MockUsageClient) { + mocks: func(tt *test, cuc *mock_compute.MockUsageClient, nuc *mock_armnetwork.MockUsagesClient) { cuc.EXPECT(). List(ctx, "ocLocation"). Return([]mgmtcompute.Usage{ @@ -63,10 +63,10 @@ func TestValidateQuota(t *testing.T) { }, }, nil) nuc.EXPECT(). - List(ctx, "ocLocation"). - Return([]mgmtnetwork.Usage{ + List(ctx, "ocLocation", nil). + Return([]*sdknetwork.Usage{ { - Name: &mgmtnetwork.UsageName{ + Name: &sdknetwork.UsageName{ Value: to.StringPtr("PublicIPAddresses"), }, CurrentValue: to.Int64Ptr(4), @@ -78,7 +78,7 @@ func TestValidateQuota(t *testing.T) { { name: "not enough cores", wantErr: "400: ResourceQuotaExceeded: : Resource quota of cores exceeded. Maximum allowed: 212, Current in use: 101, Additional requested: 112.", - mocks: func(tt *test, cuc *mock_compute.MockUsageClient, nuc *mock_network.MockUsageClient) { + mocks: func(tt *test, cuc *mock_compute.MockUsageClient, nuc *mock_armnetwork.MockUsagesClient) { cuc.EXPECT(). List(ctx, "ocLocation"). Return([]mgmtcompute.Usage{ @@ -95,7 +95,7 @@ func TestValidateQuota(t *testing.T) { { name: "not enough virtualMachines", wantErr: "400: ResourceQuotaExceeded: : Resource quota of virtualMachines exceeded. Maximum allowed: 114, Current in use: 101, Additional requested: 14.", - mocks: func(tt *test, cuc *mock_compute.MockUsageClient, nuc *mock_network.MockUsageClient) { + mocks: func(tt *test, cuc *mock_compute.MockUsageClient, nuc *mock_armnetwork.MockUsagesClient) { cuc.EXPECT(). List(ctx, "ocLocation"). Return([]mgmtcompute.Usage{ @@ -112,7 +112,7 @@ func TestValidateQuota(t *testing.T) { { name: "not enough standardDSv3Family", wantErr: "400: ResourceQuotaExceeded: : Resource quota of standardDSv3Family exceeded. Maximum allowed: 212, Current in use: 101, Additional requested: 112.", - mocks: func(tt *test, cuc *mock_compute.MockUsageClient, nuc *mock_network.MockUsageClient) { + mocks: func(tt *test, cuc *mock_compute.MockUsageClient, nuc *mock_armnetwork.MockUsagesClient) { cuc.EXPECT(). List(ctx, "ocLocation"). Return([]mgmtcompute.Usage{ @@ -129,7 +129,7 @@ func TestValidateQuota(t *testing.T) { { name: "not enough premium disks", wantErr: "400: ResourceQuotaExceeded: : Resource quota of PremiumDiskCount exceeded. Maximum allowed: 114, Current in use: 101, Additional requested: 14.", - mocks: func(tt *test, cuc *mock_compute.MockUsageClient, nuc *mock_network.MockUsageClient) { + mocks: func(tt *test, cuc *mock_compute.MockUsageClient, nuc *mock_armnetwork.MockUsagesClient) { cuc.EXPECT(). List(ctx, "ocLocation"). Return([]mgmtcompute.Usage{ @@ -146,15 +146,15 @@ func TestValidateQuota(t *testing.T) { { name: "not enough public ip addresses", wantErr: "400: ResourceQuotaExceeded: : Resource quota of PublicIPAddresses exceeded. Maximum allowed: 6, Current in use: 4, Additional requested: 3.", - mocks: func(tt *test, cuc *mock_compute.MockUsageClient, nuc *mock_network.MockUsageClient) { + mocks: func(tt *test, cuc *mock_compute.MockUsageClient, nuc *mock_armnetwork.MockUsagesClient) { cuc.EXPECT(). List(ctx, "ocLocation"). Return([]mgmtcompute.Usage{}, nil) nuc.EXPECT(). - List(ctx, "ocLocation"). - Return([]mgmtnetwork.Usage{ + List(ctx, "ocLocation", nil). + Return([]*sdknetwork.Usage{ { - Name: &mgmtnetwork.UsageName{ + Name: &sdknetwork.UsageName{ Value: to.StringPtr("PublicIPAddresses"), }, CurrentValue: to.Int64Ptr(4), @@ -169,7 +169,7 @@ func TestValidateQuota(t *testing.T) { defer controller.Finish() computeUsageClient := mock_compute.NewMockUsageClient(controller) - networkUsageClient := mock_network.NewMockUsageClient(controller) + networkUsageClient := mock_armnetwork.NewMockUsagesClient(controller) if tt.mocks != nil { tt.mocks(tt, computeUsageClient, networkUsageClient) } diff --git a/pkg/frontend/quota_validation.go b/pkg/frontend/quota_validation.go index a8cd8657ce1..405eb8807a1 100644 --- a/pkg/frontend/quota_validation.go +++ b/pkg/frontend/quota_validation.go @@ -11,8 +11,8 @@ import ( "github.com/Azure/ARO-RP/pkg/api/validate" "github.com/Azure/ARO-RP/pkg/env" "github.com/Azure/ARO-RP/pkg/util/azureclient" + "github.com/Azure/ARO-RP/pkg/util/azureclient/azuresdk/armnetwork" "github.com/Azure/ARO-RP/pkg/util/azureclient/mgmt/compute" - "github.com/Azure/ARO-RP/pkg/util/azureclient/mgmt/network" ) type QuotaValidator interface { @@ -44,13 +44,22 @@ func (q quotaValidator) ValidateQuota(ctx context.Context, azEnv *azureclient.AR return err } + credential, err := environment.FPNewClientCertificateCredential(tenantID) + if err != nil { + return err + } + options := environment.Environment().ArmClientOptions() + spComputeUsage := compute.NewUsageClient(azEnv, subscriptionID, fpAuthorizer) - spNetworkUsage := network.NewUsageClient(azEnv, subscriptionID, fpAuthorizer) + spNetworkUsage, err := armnetwork.NewUsagesClient(subscriptionID, credential, options) + if err != nil { + return err + } return validateQuota(ctx, oc, spNetworkUsage, spComputeUsage) } -func validateQuota(ctx context.Context, oc *api.OpenShiftCluster, spNetworkUsage network.UsageClient, spComputeUsage compute.UsageClient) error { +func validateQuota(ctx context.Context, oc *api.OpenShiftCluster, spNetworkUsage armnetwork.UsagesClient, spComputeUsage compute.UsageClient) error { // If ValidateQuota runs outside install process, we should skip quota validation requiredResources := map[string]int{} @@ -89,7 +98,7 @@ func validateQuota(ctx context.Context, oc *api.OpenShiftCluster, spNetworkUsage } } - netUsages, err := spNetworkUsage.List(ctx, oc.Location) + netUsages, err := spNetworkUsage.List(ctx, oc.Location, nil) if err != nil { return err } diff --git a/pkg/util/azureclient/azuresdk/armnetwork/generate.go b/pkg/util/azureclient/azuresdk/armnetwork/generate.go index 48215a47500..1c78bea86fe 100644 --- a/pkg/util/azureclient/azuresdk/armnetwork/generate.go +++ b/pkg/util/azureclient/azuresdk/armnetwork/generate.go @@ -4,5 +4,5 @@ package armnetwork // Licensed under the Apache License 2.0. //go:generate rm -rf ../../../../util/mocks/$GOPACKAGE -//go:generate mockgen -destination=../../../../util/mocks/azureclient/azuresdk/$GOPACKAGE/$GOPACKAGE.go github.com/Azure/ARO-RP/pkg/util/azureclient/azuresdk/$GOPACKAGE InterfacesClient,LoadBalancersClient,LoadBalancerBackendAddressPoolsClient,PrivateEndpointsClient,PrivateLinkServicesClient,PublicIPAddressesClient,SecurityGroupsClient,SubnetsClient +//go:generate mockgen -destination=../../../../util/mocks/azureclient/azuresdk/$GOPACKAGE/$GOPACKAGE.go github.com/Azure/ARO-RP/pkg/util/azureclient/azuresdk/$GOPACKAGE InterfacesClient,LoadBalancersClient,LoadBalancerBackendAddressPoolsClient,PrivateEndpointsClient,PrivateLinkServicesClient,PublicIPAddressesClient,SecurityGroupsClient,SubnetsClient,UsagesClient,VirtualNetworksClient //go:generate goimports -local=github.com/Azure/ARO-RP -e -w ../../../../util/mocks/azureclient/azuresdk/$GOPACKAGE/$GOPACKAGE.go diff --git a/pkg/util/azureclient/azuresdk/armnetwork/usage.go b/pkg/util/azureclient/azuresdk/armnetwork/usage.go new file mode 100644 index 00000000000..c5fefa06ee7 --- /dev/null +++ b/pkg/util/azureclient/azuresdk/armnetwork/usage.go @@ -0,0 +1,27 @@ +package armnetwork + +// Copyright (c) Microsoft Corporation. +// Licensed under the Apache License 2.0. + +import ( + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v2" + + "github.com/Azure/ARO-RP/pkg/util/azureclient/azuresdk/azcore" +) + +// UsagesClient is a minimal interface for azure UsageClient +type UsagesClient interface { + UsageClientAddons +} + +type usagesClient struct { + *armnetwork.UsagesClient +} + +// NewUsagesClient creates a new UsageClient +func NewUsagesClient(subscriptionID string, credential azcore.TokenCredential, options *arm.ClientOptions) (UsagesClient, error) { + client, err := armnetwork.NewUsagesClient(subscriptionID, credential, options) + + return &usagesClient{client}, err +} diff --git a/pkg/util/azureclient/azuresdk/armnetwork/usage_addons.go b/pkg/util/azureclient/azuresdk/armnetwork/usage_addons.go new file mode 100644 index 00000000000..0d5dcea89b5 --- /dev/null +++ b/pkg/util/azureclient/azuresdk/armnetwork/usage_addons.go @@ -0,0 +1,28 @@ +package armnetwork + +// Copyright (c) Microsoft Corporation. +// Licensed under the Apache License 2.0. + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v2" +) + +// UsageClientAddons contains addons to UsageClient +type UsageClientAddons interface { + List(ctx context.Context, location string, options *armnetwork.UsagesClientListOptions) (result []*armnetwork.Usage, err error) +} + +func (c *usagesClient) List(ctx context.Context, location string, options *armnetwork.UsagesClientListOptions) (result []*armnetwork.Usage, err error) { + pager := c.UsagesClient.NewListPager(location, options) + + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, err + } + result = append(result, page.Value...) + } + return result, nil +} diff --git a/pkg/util/azureclient/mgmt/network/generate.go b/pkg/util/azureclient/mgmt/network/generate.go index 8f32bf26402..afd6f558371 100644 --- a/pkg/util/azureclient/mgmt/network/generate.go +++ b/pkg/util/azureclient/mgmt/network/generate.go @@ -4,5 +4,5 @@ package network // Licensed under the Apache License 2.0. //go:generate rm -rf ../../../../util/mocks/$GOPACKAGE -//go:generate mockgen -destination=../../../../util/mocks/azureclient/mgmt/$GOPACKAGE/$GOPACKAGE.go github.com/Azure/ARO-RP/pkg/util/azureclient/mgmt/$GOPACKAGE InterfacesClient,LoadBalancersClient,PrivateEndpointsClient,PublicIPAddressesClient,LoadBalancerBackendAddressPoolsClient,RouteTablesClient,SubnetsClient,VirtualNetworksClient,UsageClient,FlowLogsClient +//go:generate mockgen -destination=../../../../util/mocks/azureclient/mgmt/$GOPACKAGE/$GOPACKAGE.go github.com/Azure/ARO-RP/pkg/util/azureclient/mgmt/$GOPACKAGE InterfacesClient,LoadBalancersClient,PrivateEndpointsClient,PublicIPAddressesClient,LoadBalancerBackendAddressPoolsClient,RouteTablesClient,SubnetsClient,VirtualNetworksClient,FlowLogsClient //go:generate goimports -local=github.com/Azure/ARO-RP -e -w ../../../../util/mocks/azureclient/mgmt/$GOPACKAGE/$GOPACKAGE.go diff --git a/pkg/util/azureclient/mgmt/network/usage.go b/pkg/util/azureclient/mgmt/network/usage.go deleted file mode 100644 index 72d42d2c1b8..00000000000 --- a/pkg/util/azureclient/mgmt/network/usage.go +++ /dev/null @@ -1,33 +0,0 @@ -package network - -// Copyright (c) Microsoft Corporation. -// Licensed under the Apache License 2.0. - -import ( - mgmtnetwork "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2020-08-01/network" - "github.com/Azure/go-autorest/autorest" - - "github.com/Azure/ARO-RP/pkg/util/azureclient" -) - -// UsageClient is a minimal interface for azure UsageClient -type UsageClient interface { - UsageClientAddons -} - -type usageClient struct { - mgmtnetwork.UsagesClient -} - -var _ UsageClient = &usageClient{} - -// NewUsageClient creates a new UsageClient -func NewUsageClient(environment *azureclient.AROEnvironment, tenantID string, authorizer autorest.Authorizer) UsageClient { - client := mgmtnetwork.NewUsagesClientWithBaseURI(environment.ResourceManagerEndpoint, tenantID) - client.Authorizer = authorizer - client.Sender = azureclient.DecorateSenderWithLogging(client.Sender) - - return &usageClient{ - UsagesClient: client, - } -} diff --git a/pkg/util/azureclient/mgmt/network/usage_addons.go b/pkg/util/azureclient/mgmt/network/usage_addons.go deleted file mode 100644 index 77d02b800c3..00000000000 --- a/pkg/util/azureclient/mgmt/network/usage_addons.go +++ /dev/null @@ -1,32 +0,0 @@ -package network - -// Copyright (c) Microsoft Corporation. -// Licensed under the Apache License 2.0. - -import ( - "context" - - mgmtnetwork "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2020-08-01/network" -) - -// UsageClientAddons contains addons to UsageClient -type UsageClientAddons interface { - List(ctx context.Context, location string) (result []mgmtnetwork.Usage, err error) -} - -func (u *usageClient) List(ctx context.Context, location string) (result []mgmtnetwork.Usage, err error) { - page, err := u.UsagesClient.List(ctx, location) - if err != nil { - return nil, err - } - - for page.NotDone() { - result = append(result, page.Values()...) - err = page.NextWithContext(ctx) - if err != nil { - return nil, err - } - } - - return result, nil -} diff --git a/pkg/util/mocks/azureclient/azuresdk/armnetwork/armnetwork.go b/pkg/util/mocks/azureclient/azuresdk/armnetwork/armnetwork.go index a9b96b77b08..618fdea90b4 100644 --- a/pkg/util/mocks/azureclient/azuresdk/armnetwork/armnetwork.go +++ b/pkg/util/mocks/azureclient/azuresdk/armnetwork/armnetwork.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/Azure/ARO-RP/pkg/util/azureclient/azuresdk/armnetwork (interfaces: InterfacesClient,LoadBalancersClient,LoadBalancerBackendAddressPoolsClient,PrivateEndpointsClient,PrivateLinkServicesClient,PublicIPAddressesClient,SecurityGroupsClient,SubnetsClient) +// Source: github.com/Azure/ARO-RP/pkg/util/azureclient/azuresdk/armnetwork (interfaces: InterfacesClient,LoadBalancersClient,LoadBalancerBackendAddressPoolsClient,PrivateEndpointsClient,PrivateLinkServicesClient,PublicIPAddressesClient,SecurityGroupsClient,SubnetsClient,UsagesClient,VirtualNetworksClient) // // Generated by this command: // -// mockgen -destination=../../../../util/mocks/azureclient/azuresdk/armnetwork/armnetwork.go github.com/Azure/ARO-RP/pkg/util/azureclient/azuresdk/armnetwork InterfacesClient,LoadBalancersClient,LoadBalancerBackendAddressPoolsClient,PrivateEndpointsClient,PrivateLinkServicesClient,PublicIPAddressesClient,SecurityGroupsClient,SubnetsClient +// mockgen -destination=../../../../util/mocks/azureclient/azuresdk/armnetwork/armnetwork.go github.com/Azure/ARO-RP/pkg/util/azureclient/azuresdk/armnetwork InterfacesClient,LoadBalancersClient,LoadBalancerBackendAddressPoolsClient,PrivateEndpointsClient,PrivateLinkServicesClient,PublicIPAddressesClient,SecurityGroupsClient,SubnetsClient,UsagesClient,VirtualNetworksClient // // Package mock_armnetwork is a generated GoMock package. @@ -563,3 +563,79 @@ func (mr *MockSubnetsClientMockRecorder) List(arg0, arg1, arg2, arg3 any) *gomoc mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockSubnetsClient)(nil).List), arg0, arg1, arg2, arg3) } + +// MockUsagesClient is a mock of UsagesClient interface. +type MockUsagesClient struct { + ctrl *gomock.Controller + recorder *MockUsagesClientMockRecorder +} + +// MockUsagesClientMockRecorder is the mock recorder for MockUsagesClient. +type MockUsagesClientMockRecorder struct { + mock *MockUsagesClient +} + +// NewMockUsagesClient creates a new mock instance. +func NewMockUsagesClient(ctrl *gomock.Controller) *MockUsagesClient { + mock := &MockUsagesClient{ctrl: ctrl} + mock.recorder = &MockUsagesClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockUsagesClient) EXPECT() *MockUsagesClientMockRecorder { + return m.recorder +} + +// List mocks base method. +func (m *MockUsagesClient) List(arg0 context.Context, arg1 string, arg2 *armnetwork.UsagesClientListOptions) ([]*armnetwork.Usage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List", arg0, arg1, arg2) + ret0, _ := ret[0].([]*armnetwork.Usage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// List indicates an expected call of List. +func (mr *MockUsagesClientMockRecorder) List(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockUsagesClient)(nil).List), arg0, arg1, arg2) +} + +// MockVirtualNetworksClient is a mock of VirtualNetworksClient interface. +type MockVirtualNetworksClient struct { + ctrl *gomock.Controller + recorder *MockVirtualNetworksClientMockRecorder +} + +// MockVirtualNetworksClientMockRecorder is the mock recorder for MockVirtualNetworksClient. +type MockVirtualNetworksClientMockRecorder struct { + mock *MockVirtualNetworksClient +} + +// NewMockVirtualNetworksClient creates a new mock instance. +func NewMockVirtualNetworksClient(ctrl *gomock.Controller) *MockVirtualNetworksClient { + mock := &MockVirtualNetworksClient{ctrl: ctrl} + mock.recorder = &MockVirtualNetworksClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockVirtualNetworksClient) EXPECT() *MockVirtualNetworksClientMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockVirtualNetworksClient) Get(arg0 context.Context, arg1, arg2 string, arg3 *armnetwork.VirtualNetworksClientGetOptions) (armnetwork.VirtualNetworksClientGetResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(armnetwork.VirtualNetworksClientGetResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockVirtualNetworksClientMockRecorder) Get(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockVirtualNetworksClient)(nil).Get), arg0, arg1, arg2, arg3) +} diff --git a/pkg/util/mocks/azureclient/mgmt/network/network.go b/pkg/util/mocks/azureclient/mgmt/network/network.go index 5515a6ca80a..1e29b566556 100644 --- a/pkg/util/mocks/azureclient/mgmt/network/network.go +++ b/pkg/util/mocks/azureclient/mgmt/network/network.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/Azure/ARO-RP/pkg/util/azureclient/mgmt/network (interfaces: InterfacesClient,LoadBalancersClient,PrivateEndpointsClient,PublicIPAddressesClient,LoadBalancerBackendAddressPoolsClient,RouteTablesClient,SubnetsClient,VirtualNetworksClient,UsageClient,FlowLogsClient) +// Source: github.com/Azure/ARO-RP/pkg/util/azureclient/mgmt/network (interfaces: InterfacesClient,LoadBalancersClient,PrivateEndpointsClient,PublicIPAddressesClient,LoadBalancerBackendAddressPoolsClient,RouteTablesClient,SubnetsClient,VirtualNetworksClient,FlowLogsClient) // // Generated by this command: // -// mockgen -destination=../../../../util/mocks/azureclient/mgmt/network/network.go github.com/Azure/ARO-RP/pkg/util/azureclient/mgmt/network InterfacesClient,LoadBalancersClient,PrivateEndpointsClient,PublicIPAddressesClient,LoadBalancerBackendAddressPoolsClient,RouteTablesClient,SubnetsClient,VirtualNetworksClient,UsageClient,FlowLogsClient +// mockgen -destination=../../../../util/mocks/azureclient/mgmt/network/network.go github.com/Azure/ARO-RP/pkg/util/azureclient/mgmt/network InterfacesClient,LoadBalancersClient,PrivateEndpointsClient,PublicIPAddressesClient,LoadBalancerBackendAddressPoolsClient,RouteTablesClient,SubnetsClient,VirtualNetworksClient,FlowLogsClient // // Package mock_network is a generated GoMock package. @@ -491,44 +491,6 @@ func (mr *MockVirtualNetworksClientMockRecorder) Get(arg0, arg1, arg2, arg3 any) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockVirtualNetworksClient)(nil).Get), arg0, arg1, arg2, arg3) } -// MockUsageClient is a mock of UsageClient interface. -type MockUsageClient struct { - ctrl *gomock.Controller - recorder *MockUsageClientMockRecorder -} - -// MockUsageClientMockRecorder is the mock recorder for MockUsageClient. -type MockUsageClientMockRecorder struct { - mock *MockUsageClient -} - -// NewMockUsageClient creates a new mock instance. -func NewMockUsageClient(ctrl *gomock.Controller) *MockUsageClient { - mock := &MockUsageClient{ctrl: ctrl} - mock.recorder = &MockUsageClientMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockUsageClient) EXPECT() *MockUsageClientMockRecorder { - return m.recorder -} - -// List mocks base method. -func (m *MockUsageClient) List(arg0 context.Context, arg1 string) ([]network.Usage, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "List", arg0, arg1) - ret0, _ := ret[0].([]network.Usage) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// List indicates an expected call of List. -func (mr *MockUsageClientMockRecorder) List(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockUsageClient)(nil).List), arg0, arg1) -} - // MockFlowLogsClient is a mock of FlowLogsClient interface. type MockFlowLogsClient struct { ctrl *gomock.Controller diff --git a/pkg/validate/dynamic/cache_vnet.go b/pkg/validate/dynamic/cache_vnet.go index ab6cce91b9f..1c3d89e1d9d 100644 --- a/pkg/validate/dynamic/cache_vnet.go +++ b/pkg/validate/dynamic/cache_vnet.go @@ -6,35 +6,35 @@ package dynamic import ( "context" - mgmtnetwork "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2020-08-01/network" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v2" ) type virtualNetworksGetClient interface { - Get(context.Context, string, string, string) (mgmtnetwork.VirtualNetwork, error) + Get(ctx context.Context, resourceGroupName string, virtualNetworkName string, options *armnetwork.VirtualNetworksClientGetOptions) (vnet armnetwork.VirtualNetworksClientGetResponse, err error) } type virtualNetworksCacheKey struct { resourceGroupName string virtualNetworkName string - expand string + options *armnetwork.VirtualNetworksClientGetOptions } type virtualNetworksCache struct { c virtualNetworksGetClient - m map[virtualNetworksCacheKey]mgmtnetwork.VirtualNetwork + m map[virtualNetworksCacheKey]armnetwork.VirtualNetworksClientGetResponse } -func (vnc *virtualNetworksCache) Get(ctx context.Context, resourceGroupName string, virtualNetworkName string, expand string) (mgmtnetwork.VirtualNetwork, error) { - if _, ok := vnc.m[virtualNetworksCacheKey{resourceGroupName, virtualNetworkName, expand}]; !ok { - vnet, err := vnc.c.Get(ctx, resourceGroupName, virtualNetworkName, expand) +func (vnc *virtualNetworksCache) Get(ctx context.Context, resourceGroupName string, virtualNetworkName string, options *armnetwork.VirtualNetworksClientGetOptions) (armnetwork.VirtualNetworksClientGetResponse, error) { + if _, ok := vnc.m[virtualNetworksCacheKey{resourceGroupName, virtualNetworkName, options}]; !ok { + vnet, err := vnc.c.Get(ctx, resourceGroupName, virtualNetworkName, options) if err != nil { return vnet, err } - vnc.m[virtualNetworksCacheKey{resourceGroupName, virtualNetworkName, expand}] = vnet + vnc.m[virtualNetworksCacheKey{resourceGroupName, virtualNetworkName, options}] = vnet } - return vnc.m[virtualNetworksCacheKey{resourceGroupName, virtualNetworkName, expand}], nil + return vnc.m[virtualNetworksCacheKey{resourceGroupName, virtualNetworkName, options}], nil } // newVirtualNetworksCache returns a new virtualNetworksCache. It knows nothing @@ -43,6 +43,6 @@ func (vnc *virtualNetworksCache) Get(ctx context.Context, resourceGroupName stri func newVirtualNetworksCache(c virtualNetworksGetClient) virtualNetworksGetClient { return &virtualNetworksCache{ c: c, - m: map[virtualNetworksCacheKey]mgmtnetwork.VirtualNetwork{}, + m: map[virtualNetworksCacheKey]armnetwork.VirtualNetworksClientGetResponse{}, } } diff --git a/pkg/validate/dynamic/dynamic.go b/pkg/validate/dynamic/dynamic.go index fb5f2069cdb..e9ba5de0c4a 100644 --- a/pkg/validate/dynamic/dynamic.go +++ b/pkg/validate/dynamic/dynamic.go @@ -13,7 +13,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - mgmtnetwork "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2020-08-01/network" + sdknetwork "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v2" "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/azure" "github.com/apparentlymart/go-cidr/cidr" @@ -26,8 +26,8 @@ import ( "github.com/Azure/ARO-RP/pkg/util/azureclient" "github.com/Azure/ARO-RP/pkg/util/azureclient/authz/remotepdp" "github.com/Azure/ARO-RP/pkg/util/azureclient/azuresdk/armauthorization" + "github.com/Azure/ARO-RP/pkg/util/azureclient/azuresdk/armnetwork" "github.com/Azure/ARO-RP/pkg/util/azureclient/mgmt/compute" - "github.com/Azure/ARO-RP/pkg/util/azureclient/mgmt/network" "github.com/Azure/ARO-RP/pkg/util/stringutils" "github.com/Azure/ARO-RP/pkg/util/token" ) @@ -96,8 +96,8 @@ type dynamic struct { virtualNetworks virtualNetworksGetClient diskEncryptionSets compute.DiskEncryptionSetsClient resourceSkusClient compute.ResourceSkusClient - spNetworkUsage network.UsageClient - loadBalancerBackendAddressPoolsClient network.LoadBalancerBackendAddressPoolsClient + spNetworkUsage armnetwork.UsagesClient + loadBalancerBackendAddressPoolsClient armnetwork.LoadBalancerBackendAddressPoolsClient pdpClient remotepdp.RemotePDPClient } @@ -119,7 +119,24 @@ func NewValidator( authorizerType AuthorizerType, cred azcore.TokenCredential, pdpClient remotepdp.RemotePDPClient, -) Dynamic { +) (Dynamic, error) { + options := azEnv.ArmClientOptions() + + usagesClient, err := armnetwork.NewUsagesClient(subscriptionID, cred, options) + if err != nil { + return nil, err + } + + virtualNetworksClient, err := armnetwork.NewVirtualNetworksClient(subscriptionID, cred, options) + if err != nil { + return nil, err + } + + loadBalancerBackendAddressPoolsClient, err := armnetwork.NewLoadBalancerBackendAddressPoolsClient(subscriptionID, cred, options) + if err != nil { + return nil, err + } + return &dynamic{ log: log, appID: appID, @@ -128,15 +145,13 @@ func NewValidator( azEnv: azEnv, checkAccessSubjectInfoCred: cred, - spNetworkUsage: network.NewUsageClient(azEnv, subscriptionID, authorizer), - virtualNetworks: newVirtualNetworksCache( - network.NewVirtualNetworksClient(azEnv, subscriptionID, authorizer), - ), + spNetworkUsage: usagesClient, + virtualNetworks: newVirtualNetworksCache(virtualNetworksClient), diskEncryptionSets: compute.NewDiskEncryptionSetsClient(azEnv, subscriptionID, authorizer), resourceSkusClient: compute.NewResourceSkusClient(azEnv, subscriptionID, authorizer), pdpClient: pdpClient, - loadBalancerBackendAddressPoolsClient: network.NewLoadBalancerBackendAddressPoolsClient(azEnv, subscriptionID, authorizer), - } + loadBalancerBackendAddressPoolsClient: loadBalancerBackendAddressPoolsClient, + }, nil } func NewServicePrincipalValidator( @@ -294,12 +309,12 @@ func (dv *dynamic) validateRouteTablePermissions(ctx context.Context, s Subnet) return err } - vnet, err := dv.virtualNetworks.Get(ctx, vnetr.ResourceGroup, vnetr.ResourceName, "") + vnet, err := dv.virtualNetworks.Get(ctx, vnetr.ResourceGroup, vnetr.ResourceName, nil) if err != nil { return err } - rtID, err := getRouteTableID(&vnet, s.ID) + rtID, err := getRouteTableID(&vnet.VirtualNetwork, s.ID) if err != nil || rtID == "" { // error or no route table return err } @@ -366,12 +381,12 @@ func (dv *dynamic) validateNatGatewayPermissions(ctx context.Context, s Subnet) return err } - vnet, err := dv.virtualNetworks.Get(ctx, vnetr.ResourceGroup, vnetr.ResourceName, "") + vnet, err := dv.virtualNetworks.Get(ctx, vnetr.ResourceGroup, vnetr.ResourceName, nil) if err != nil { return err } - ngID, err := getNatGatewayID(&vnet, s.ID) + ngID, err := getNatGatewayID(&vnet.VirtualNetwork, s.ID) if err != nil { return err } @@ -544,27 +559,27 @@ func (dv *dynamic) validateCIDRRanges(ctx context.Context, subnets []Subnet, add return err } - vnet, err := dv.virtualNetworks.Get(ctx, vnetr.ResourceGroup, vnetr.ResourceName, "") + vnet, err := dv.virtualNetworks.Get(ctx, vnetr.ResourceGroup, vnetr.ResourceName, nil) if err != nil { return err } - s, err := findSubnet(&vnet, s.ID) + s, err := findSubnet(&vnet.VirtualNetwork, s.ID) if err != nil { return err } // Validate the CIDR of AddressPrefix or AddressPrefixes, whichever is defined - if s.AddressPrefix == nil { - for _, address := range *s.AddressPrefixes { - _, net, err := net.ParseCIDR(address) + if s.Properties.AddressPrefix == nil { + for _, address := range s.Properties.AddressPrefixes { + _, net, err := net.ParseCIDR(*address) if err != nil { return err } CIDRArray = append(CIDRArray, net) } } else { - _, net, err := net.ParseCIDR(*s.AddressPrefix) + _, net, err := net.ParseCIDR(*s.Properties.AddressPrefix) if err != nil { return err } @@ -597,7 +612,7 @@ func (dv *dynamic) validateCIDRRanges(ctx context.Context, subnets []Subnet, add func (dv *dynamic) validateVnetLocation(ctx context.Context, vnetr azure.Resource, location string) error { dv.log.Print("validateVnetLocation") - vnet, err := dv.virtualNetworks.Get(ctx, vnetr.ResourceGroup, vnetr.ResourceName, "") + vnet, err := dv.virtualNetworks.Get(ctx, vnetr.ResourceGroup, vnetr.ResourceName, nil) if err != nil { return err } @@ -616,11 +631,11 @@ func (dv *dynamic) validateVnetLocation(ctx context.Context, vnetr azure.Resourc return nil } -func (dv *dynamic) createSubnetMapByID(ctx context.Context, subnets []Subnet) (map[string]*mgmtnetwork.Subnet, error) { +func (dv *dynamic) createSubnetMapByID(ctx context.Context, subnets []Subnet) (map[string]*sdknetwork.Subnet, error) { if len(subnets) == 0 { return nil, fmt.Errorf("no subnets found") } - subnetByID := make(map[string]*mgmtnetwork.Subnet) + subnetByID := make(map[string]*sdknetwork.Subnet) for _, s := range subnets { vnetID, _, err := apisubnet.Split(s.ID) @@ -631,12 +646,12 @@ func (dv *dynamic) createSubnetMapByID(ctx context.Context, subnets []Subnet) (m if err != nil { return nil, err } - vnet, err := dv.virtualNetworks.Get(ctx, vnetr.ResourceGroup, vnetr.ResourceName, "") + vnet, err := dv.virtualNetworks.Get(ctx, vnetr.ResourceGroup, vnetr.ResourceName, nil) if err != nil { return nil, err } - ss, err := findSubnet(&vnet, s.ID) + ss, err := findSubnet(&vnet.VirtualNetwork, s.ID) if err != nil { return nil, err } @@ -659,7 +674,7 @@ func (dv *dynamic) createSubnetMapByID(ctx context.Context, subnets []Subnet) (m // checkPreconfiguredNSG checks whether all the subnets have an NSG attached. // when the PreconfigureNSG feature flag is on and not all subnets are attached, // it returns an error. -func (dv *dynamic) checkPreconfiguredNSG(subnetByID map[string]*mgmtnetwork.Subnet) error { +func (dv *dynamic) checkPreconfiguredNSG(subnetByID map[string]*sdknetwork.Subnet) error { var attached int for _, subnet := range subnetByID { if subnetHasNSGAttached(subnet) { @@ -709,7 +724,7 @@ func (dv *dynamic) ValidateSubnets(ctx context.Context, oc *api.OpenShiftCluster if err != nil { return err } - if !isTheSameNSG(*ss.SubnetPropertiesFormat.NetworkSecurityGroup.ID, expectedNsgID) { + if !isTheSameNSG(*ss.Properties.NetworkSecurityGroup.ID, expectedNsgID) { return api.NewCloudError( http.StatusBadRequest, api.CloudErrorCodeInvalidLinkedVNet, @@ -723,7 +738,7 @@ func (dv *dynamic) ValidateSubnets(ctx context.Context, oc *api.OpenShiftCluster } if oc.Properties.NetworkProfile.PreconfiguredNSG == api.PreconfiguredNSGDisabled { if !subnetHasNSGAttached(ss) || - !isTheSameNSG(*ss.SubnetPropertiesFormat.NetworkSecurityGroup.ID, nsgID) { + !isTheSameNSG(*ss.Properties.NetworkSecurityGroup.ID, nsgID) { return api.NewCloudError( http.StatusBadRequest, api.CloudErrorCodeInvalidLinkedVNet, @@ -746,8 +761,7 @@ func (dv *dynamic) ValidateSubnets(ctx context.Context, oc *api.OpenShiftCluster } } - if ss.SubnetPropertiesFormat == nil || - ss.SubnetPropertiesFormat.ProvisioningState != mgmtnetwork.Succeeded { + if ss.Properties == nil || ss.Properties.ProvisioningState == nil || *ss.Properties.ProvisioningState != sdknetwork.ProvisioningStateSucceeded { return api.NewCloudError( http.StatusBadRequest, api.CloudErrorCodeInvalidLinkedVNet, @@ -758,14 +772,14 @@ func (dv *dynamic) ValidateSubnets(ctx context.Context, oc *api.OpenShiftCluster } // Handle both addressPrefix & addressPrefixes - if ss.AddressPrefix == nil { - for _, address := range *ss.AddressPrefixes { - if err = validateSubnetSize(s, address); err != nil { + if ss.Properties.AddressPrefix == nil { + for _, address := range ss.Properties.AddressPrefixes { + if err = validateSubnetSize(s, *address); err != nil { return err } } } else { - if err = validateSubnetSize(s, *ss.AddressPrefix); err != nil { + if err = validateSubnetSize(s, *ss.Properties.AddressPrefix); err != nil { return err } } @@ -808,7 +822,7 @@ func (dv *dynamic) ValidatePreConfiguredNSGs(ctx context.Context, oc *api.OpenSh } for _, s := range subnetByID { - nsgID := s.NetworkSecurityGroup.ID + nsgID := s.Properties.NetworkSecurityGroup.ID if nsgID == nil || *nsgID == "" { return api.NewCloudError( http.StatusBadRequest, @@ -886,41 +900,41 @@ func isTheSameNSG(found, inDB string) bool { return strings.EqualFold(found, inDB) } -func subnetHasNSGAttached(subnet *mgmtnetwork.Subnet) bool { - return subnet.NetworkSecurityGroup != nil && subnet.NetworkSecurityGroup.ID != nil +func subnetHasNSGAttached(subnet *sdknetwork.Subnet) bool { + return subnet.Properties.NetworkSecurityGroup != nil && subnet.Properties.NetworkSecurityGroup.ID != nil } -func getRouteTableID(vnet *mgmtnetwork.VirtualNetwork, subnetID string) (string, error) { +func getRouteTableID(vnet *sdknetwork.VirtualNetwork, subnetID string) (string, error) { s, err := findSubnet(vnet, subnetID) if err != nil { return "", err } - if s == nil || s.RouteTable == nil { + if s == nil || s.Properties.RouteTable == nil { return "", nil } - return *s.RouteTable.ID, nil + return *s.Properties.RouteTable.ID, nil } -func getNatGatewayID(vnet *mgmtnetwork.VirtualNetwork, subnetID string) (string, error) { +func getNatGatewayID(vnet *sdknetwork.VirtualNetwork, subnetID string) (string, error) { s, err := findSubnet(vnet, subnetID) if err != nil { return "", err } - if s == nil || s.NatGateway == nil { + if s == nil || s.Properties.NatGateway == nil { return "", nil } - return *s.NatGateway.ID, nil + return *s.Properties.NatGateway.ID, nil } -func findSubnet(vnet *mgmtnetwork.VirtualNetwork, subnetID string) (*mgmtnetwork.Subnet, error) { - if vnet.Subnets != nil { - for _, s := range *vnet.Subnets { +func findSubnet(vnet *sdknetwork.VirtualNetwork, subnetID string) (*sdknetwork.Subnet, error) { + if vnet.Properties.Subnets != nil { + for _, s := range vnet.Properties.Subnets { if strings.EqualFold(*s.ID, subnetID) { - return &s, nil + return s, nil } } } diff --git a/pkg/validate/dynamic/dynamic_test.go b/pkg/validate/dynamic/dynamic_test.go index 0efcf266211..ce245530d69 100644 --- a/pkg/validate/dynamic/dynamic_test.go +++ b/pkg/validate/dynamic/dynamic_test.go @@ -11,18 +11,18 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" - mgmtnetwork "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2020-08-01/network" + sdknetwork "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v2" "github.com/Azure/go-autorest/autorest/azure" - "github.com/Azure/go-autorest/autorest/to" "github.com/sirupsen/logrus" "go.uber.org/mock/gomock" + "k8s.io/utils/ptr" "github.com/Azure/ARO-RP/pkg/api" "github.com/Azure/ARO-RP/pkg/util/azureclient" "github.com/Azure/ARO-RP/pkg/util/azureclient/authz/remotepdp" mock_remotepdp "github.com/Azure/ARO-RP/pkg/util/mocks/azureclient/authz/remotepdp" + mock_armnetwork "github.com/Azure/ARO-RP/pkg/util/mocks/azureclient/azuresdk/armnetwork" mock_azcore "github.com/Azure/ARO-RP/pkg/util/mocks/azureclient/azuresdk/azcore" - mock_network "github.com/Azure/ARO-RP/pkg/util/mocks/azureclient/mgmt/network" "github.com/Azure/ARO-RP/pkg/util/uuid" utilerror "github.com/Azure/ARO-RP/test/util/error" ) @@ -51,7 +51,7 @@ var ( func TestGetRouteTableID(t *testing.T) { for _, tt := range []struct { name string - modifyVnet func(*mgmtnetwork.VirtualNetwork) + modifyVnet func(*sdknetwork.VirtualNetwork) wantErr string }{ { @@ -59,14 +59,14 @@ func TestGetRouteTableID(t *testing.T) { }, { name: "pass: no route table", - modifyVnet: func(vnet *mgmtnetwork.VirtualNetwork) { - (*vnet.Subnets)[0].RouteTable = nil + modifyVnet: func(vnet *sdknetwork.VirtualNetwork) { + vnet.Properties.Subnets[0].Properties.RouteTable = nil }, }, { name: "fail: can't find subnet", - modifyVnet: func(vnet *mgmtnetwork.VirtualNetwork) { - vnet.Subnets = nil + modifyVnet: func(vnet *sdknetwork.VirtualNetwork) { + vnet.Properties.Subnets = nil }, wantErr: "400: InvalidLinkedVNet: : The provided subnet '" + masterSubnet + "' could not be found.", }, @@ -75,14 +75,14 @@ func TestGetRouteTableID(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - vnet := &mgmtnetwork.VirtualNetwork{ + vnet := &sdknetwork.VirtualNetwork{ ID: &vnetID, - VirtualNetworkPropertiesFormat: &mgmtnetwork.VirtualNetworkPropertiesFormat{ - Subnets: &[]mgmtnetwork.Subnet{ + Properties: &sdknetwork.VirtualNetworkPropertiesFormat{ + Subnets: []*sdknetwork.Subnet{ { ID: &masterSubnet, - SubnetPropertiesFormat: &mgmtnetwork.SubnetPropertiesFormat{ - RouteTable: &mgmtnetwork.RouteTable{ + Properties: &sdknetwork.SubnetPropertiesFormat{ + RouteTable: &sdknetwork.RouteTable{ ID: &masterRtID, }, }, @@ -105,7 +105,7 @@ func TestGetRouteTableID(t *testing.T) { func TestGetNatGatewayID(t *testing.T) { for _, tt := range []struct { name string - modifyVnet func(*mgmtnetwork.VirtualNetwork) + modifyVnet func(*sdknetwork.VirtualNetwork) wantErr string }{ { @@ -113,14 +113,14 @@ func TestGetNatGatewayID(t *testing.T) { }, { name: "pass: no nat gateway", - modifyVnet: func(vnet *mgmtnetwork.VirtualNetwork) { - (*vnet.Subnets)[0].NatGateway = nil + modifyVnet: func(vnet *sdknetwork.VirtualNetwork) { + vnet.Properties.Subnets[0].Properties.NatGateway = nil }, }, { name: "fail: can't find subnet", - modifyVnet: func(vnet *mgmtnetwork.VirtualNetwork) { - vnet.Subnets = nil + modifyVnet: func(vnet *sdknetwork.VirtualNetwork) { + vnet.Properties.Subnets = nil }, wantErr: "400: InvalidLinkedVNet: : The provided subnet '" + masterSubnet + "' could not be found.", }, @@ -129,14 +129,14 @@ func TestGetNatGatewayID(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - vnet := &mgmtnetwork.VirtualNetwork{ + vnet := &sdknetwork.VirtualNetwork{ ID: &vnetID, - VirtualNetworkPropertiesFormat: &mgmtnetwork.VirtualNetworkPropertiesFormat{ - Subnets: &[]mgmtnetwork.Subnet{ + Properties: &sdknetwork.VirtualNetworkPropertiesFormat{ + Subnets: []*sdknetwork.Subnet{ { ID: &masterSubnet, - SubnetPropertiesFormat: &mgmtnetwork.SubnetPropertiesFormat{ - NatGateway: &mgmtnetwork.SubResource{ + Properties: &sdknetwork.SubnetPropertiesFormat{ + NatGateway: &sdknetwork.SubResource{ ID: &masterNgID, }, }, @@ -162,17 +162,17 @@ func TestValidateCIDRRanges(t *testing.T) { for _, tt := range []struct { name string modifyOC func(*api.OpenShiftCluster) - vnetMocks func(*mock_network.MockVirtualNetworksClient, mgmtnetwork.VirtualNetwork) + vnetMocks func(*mock_armnetwork.MockVirtualNetworksClient, sdknetwork.VirtualNetworksClientGetResponse) wantErr string }{ { name: "pass", - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, }, @@ -181,12 +181,12 @@ func TestValidateCIDRRanges(t *testing.T) { modifyOC: func(oc *api.OpenShiftCluster) { oc.Properties.NetworkProfile.ServiceCIDR = "10.0.0.0/24" }, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, wantErr: "400: InvalidLinkedVNet: : The provided CIDRs must not overlap: '10.0.0.0/24 overlaps with 10.0.0.0/24'.", @@ -219,27 +219,27 @@ func TestValidateCIDRRanges(t *testing.T) { }, } - vnets := []mgmtnetwork.VirtualNetwork{ + vnets := []sdknetwork.VirtualNetwork{ { ID: &vnetID, - Location: to.StringPtr("eastus"), - Name: to.StringPtr("VNET With AddressPrefix"), - VirtualNetworkPropertiesFormat: &mgmtnetwork.VirtualNetworkPropertiesFormat{ - Subnets: &[]mgmtnetwork.Subnet{ + Location: ptr.To("eastus"), + Name: ptr.To("VNET With AddressPrefix"), + Properties: &sdknetwork.VirtualNetworkPropertiesFormat{ + Subnets: []*sdknetwork.Subnet{ { ID: &masterSubnet, - SubnetPropertiesFormat: &mgmtnetwork.SubnetPropertiesFormat{ - AddressPrefix: to.StringPtr("10.0.0.0/24"), - NetworkSecurityGroup: &mgmtnetwork.SecurityGroup{ + Properties: &sdknetwork.SubnetPropertiesFormat{ + AddressPrefix: ptr.To("10.0.0.0/24"), + NetworkSecurityGroup: &sdknetwork.SecurityGroup{ ID: &masterNSGv1, }, }, }, { ID: &workerSubnet, - SubnetPropertiesFormat: &mgmtnetwork.SubnetPropertiesFormat{ - AddressPrefix: to.StringPtr("10.0.1.0/24"), - NetworkSecurityGroup: &mgmtnetwork.SecurityGroup{ + Properties: &sdknetwork.SubnetPropertiesFormat{ + AddressPrefix: ptr.To("10.0.1.0/24"), + NetworkSecurityGroup: &sdknetwork.SecurityGroup{ ID: &workerNSGv1, }, }, @@ -249,24 +249,24 @@ func TestValidateCIDRRanges(t *testing.T) { }, { ID: &vnetID, - Location: to.StringPtr("eastus"), - Name: to.StringPtr("VNET With AddressPrefixes"), - VirtualNetworkPropertiesFormat: &mgmtnetwork.VirtualNetworkPropertiesFormat{ - Subnets: &[]mgmtnetwork.Subnet{ + Location: ptr.To("eastus"), + Name: ptr.To("VNET With AddressPrefixes"), + Properties: &sdknetwork.VirtualNetworkPropertiesFormat{ + Subnets: []*sdknetwork.Subnet{ { ID: &masterSubnet, - SubnetPropertiesFormat: &mgmtnetwork.SubnetPropertiesFormat{ - AddressPrefixes: to.StringSlicePtr([]string{"10.0.0.0/24"}), - NetworkSecurityGroup: &mgmtnetwork.SecurityGroup{ + Properties: &sdknetwork.SubnetPropertiesFormat{ + AddressPrefixes: []*string{ptr.To("10.0.0.0/24")}, + NetworkSecurityGroup: &sdknetwork.SecurityGroup{ ID: &masterNSGv1, }, }, }, { ID: &workerSubnet, - SubnetPropertiesFormat: &mgmtnetwork.SubnetPropertiesFormat{ - AddressPrefixes: to.StringSlicePtr([]string{"10.0.1.0/24"}), - NetworkSecurityGroup: &mgmtnetwork.SecurityGroup{ + Properties: &sdknetwork.SubnetPropertiesFormat{ + AddressPrefixes: []*string{ptr.To("10.0.1.0/24")}, + NetworkSecurityGroup: &sdknetwork.SecurityGroup{ ID: &workerNSGv1, }, }, @@ -281,9 +281,9 @@ func TestValidateCIDRRanges(t *testing.T) { } for _, vnet := range vnets { - vnetClient := mock_network.NewMockVirtualNetworksClient(controller) + vnetClient := mock_armnetwork.NewMockVirtualNetworksClient(controller) if tt.vnetMocks != nil { - tt.vnetMocks(vnetClient, vnet) + tt.vnetMocks(vnetClient, sdknetwork.VirtualNetworksClientGetResponse{VirtualNetwork: vnet}) } dv := &dynamic{ @@ -324,14 +324,16 @@ func TestValidateVnetLocation(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - vnet := mgmtnetwork.VirtualNetwork{ - ID: to.StringPtr(vnetID), - Location: to.StringPtr(tt.location), + vnet := sdknetwork.VirtualNetworksClientGetResponse{ + VirtualNetwork: sdknetwork.VirtualNetwork{ + ID: ptr.To(vnetID), + Location: ptr.To(tt.location), + }, } - vnetClient := mock_network.NewMockVirtualNetworksClient(controller) + vnetClient := mock_armnetwork.NewMockVirtualNetworksClient(controller) vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) dv := &dynamic{ @@ -356,14 +358,14 @@ func TestValidateSubnets(t *testing.T) { for _, tt := range []struct { name string modifyOC func(*api.OpenShiftCluster) - vnetMocks func(*mock_network.MockVirtualNetworksClient, mgmtnetwork.VirtualNetwork) + vnetMocks func(*mock_armnetwork.MockVirtualNetworksClient, sdknetwork.VirtualNetworksClientGetResponse) wantErr string }{ { name: "pass", - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, }, @@ -374,9 +376,9 @@ func TestValidateSubnets(t *testing.T) { SubnetID: masterSubnet, } }, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, }, @@ -385,38 +387,38 @@ func TestValidateSubnets(t *testing.T) { modifyOC: func(oc *api.OpenShiftCluster) { oc.Properties.ProvisioningState = api.ProvisioningStateCreating }, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { - (*vnet.Subnets)[0].NetworkSecurityGroup = nil + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { + vnet.Properties.Subnets[0].Properties.NetworkSecurityGroup = nil vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, }, { name: "fail: subnet does not exist on vnet", - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { - vnet.Subnets = nil + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { + vnet.Properties.Subnets = nil vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, wantErr: "400: InvalidLinkedVNet: : The provided subnet '" + masterSubnet + "' could not be found.", }, { name: "pass: subnet provisioning state is succeeded", - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { - (*vnet.Subnets)[0].ProvisioningState = mgmtnetwork.Succeeded + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { + vnet.Properties.Subnets[0].Properties.ProvisioningState = ptr.To(sdknetwork.ProvisioningStateSucceeded) vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, }, { name: "fail: subnet provisioning state is not succeeded", - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { - (*vnet.Subnets)[0].ProvisioningState = mgmtnetwork.Failed + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { + vnet.Properties.Subnets[0].Properties.ProvisioningState = ptr.To(sdknetwork.ProvisioningStateFailed) vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, wantErr: "400: InvalidLinkedVNet: properties.masterProfile.subnetId: The provided subnet '" + masterSubnet + "' is not in a Succeeded state", @@ -427,9 +429,9 @@ func TestValidateSubnets(t *testing.T) { oc.Properties.ProvisioningState = api.ProvisioningStateCreating oc.Properties.NetworkProfile.PreconfiguredNSG = api.PreconfiguredNSGDisabled }, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, }, @@ -439,10 +441,10 @@ func TestValidateSubnets(t *testing.T) { oc.Properties.ProvisioningState = api.ProvisioningStateCreating oc.Properties.NetworkProfile.PreconfiguredNSG = api.PreconfiguredNSGDisabled }, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { - (*vnet.Subnets)[0].NetworkSecurityGroup.ID = to.StringPtr("not-the-correct-nsg") + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { + vnet.Properties.Subnets[0].Properties.NetworkSecurityGroup.ID = ptr.To("not-the-correct-nsg") vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, wantErr: "400: InvalidLinkedVNet: properties.masterProfile.subnetId: The provided subnet '" + masterSubnet + "' is invalid: must not have a network security group attached.", @@ -453,10 +455,10 @@ func TestValidateSubnets(t *testing.T) { oc.Properties.ProvisioningState = api.ProvisioningStateCreating oc.Properties.NetworkProfile.PreconfiguredNSG = api.PreconfiguredNSGEnabled }, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { - (*vnet.Subnets)[0].NetworkSecurityGroup.ID = to.StringPtr("attached") + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { + vnet.Properties.Subnets[0].Properties.NetworkSecurityGroup.ID = ptr.To("attached") vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, }, @@ -465,19 +467,19 @@ func TestValidateSubnets(t *testing.T) { modifyOC: func(oc *api.OpenShiftCluster) { oc.Properties.ArchitectureVersion = 9001 }, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, wantErr: "unknown architecture version 9001", }, { name: "fail: nsg id doesn't match expected", - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { - (*vnet.Subnets)[0].NetworkSecurityGroup.ID = to.StringPtr("not matching") + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { + vnet.Properties.Subnets[0].Properties.NetworkSecurityGroup.ID = ptr.To("not matching") vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, modifyOC: func(oc *api.OpenShiftCluster) { @@ -488,10 +490,10 @@ func TestValidateSubnets(t *testing.T) { }, { name: "pass: byonsg doesn't check if nsg ids are matched after creating", - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { - (*vnet.Subnets)[0].NetworkSecurityGroup.ID = to.StringPtr("don't care what it is") + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { + vnet.Properties.Subnets[0].Properties.NetworkSecurityGroup.ID = ptr.To("don't care what it is") vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, modifyOC: func(oc *api.OpenShiftCluster) { @@ -501,10 +503,10 @@ func TestValidateSubnets(t *testing.T) { }, { name: "fail: no nsg attached during update", - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { - (*vnet.Subnets)[0].NetworkSecurityGroup.ID = nil + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { + vnet.Properties.Subnets[0].Properties.NetworkSecurityGroup.ID = nil vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, modifyOC: func(oc *api.OpenShiftCluster) { @@ -515,10 +517,10 @@ func TestValidateSubnets(t *testing.T) { }, { name: "fail: byonsg requires an nsg to be attached during update", - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { - (*vnet.Subnets)[0].NetworkSecurityGroup.ID = nil + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { + vnet.Properties.Subnets[0].Properties.NetworkSecurityGroup.ID = nil vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, modifyOC: func(oc *api.OpenShiftCluster) { @@ -529,20 +531,20 @@ func TestValidateSubnets(t *testing.T) { }, { name: "fail: invalid subnet CIDR", - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { - (*vnet.Subnets)[0].AddressPrefix = to.StringPtr("not-valid") + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { + vnet.Properties.Subnets[0].Properties.AddressPrefix = ptr.To("not-valid") vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, wantErr: "invalid CIDR address: not-valid", }, { name: "fail: too small subnet CIDR", - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { - (*vnet.Subnets)[0].AddressPrefix = to.StringPtr("10.0.0.0/28") + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { + vnet.Properties.Subnets[0].Properties.AddressPrefix = ptr.To("10.0.0.0/28") vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, wantErr: "400: InvalidLinkedVNet: properties.masterProfile.subnetId: The provided subnet '" + masterSubnet + "' is invalid: must be /27 or larger.", @@ -559,18 +561,18 @@ func TestValidateSubnets(t *testing.T) { }, }, } - vnet := mgmtnetwork.VirtualNetwork{ + vnet := sdknetwork.VirtualNetwork{ ID: &vnetID, - VirtualNetworkPropertiesFormat: &mgmtnetwork.VirtualNetworkPropertiesFormat{ - Subnets: &[]mgmtnetwork.Subnet{ + Properties: &sdknetwork.VirtualNetworkPropertiesFormat{ + Subnets: []*sdknetwork.Subnet{ { ID: &masterSubnet, - SubnetPropertiesFormat: &mgmtnetwork.SubnetPropertiesFormat{ - AddressPrefix: to.StringPtr("10.0.0.0/24"), - NetworkSecurityGroup: &mgmtnetwork.SecurityGroup{ + Properties: &sdknetwork.SubnetPropertiesFormat{ + AddressPrefix: ptr.To("10.0.0.0/24"), + NetworkSecurityGroup: &sdknetwork.SecurityGroup{ ID: &masterNSGv1, }, - ProvisioningState: mgmtnetwork.Succeeded, + ProvisioningState: ptr.To(sdknetwork.ProvisioningStateSucceeded), }, }, }, @@ -580,9 +582,9 @@ func TestValidateSubnets(t *testing.T) { if tt.modifyOC != nil { tt.modifyOC(oc) } - vnetClient := mock_network.NewMockVirtualNetworksClient(controller) + vnetClient := mock_armnetwork.NewMockVirtualNetworksClient(controller) if tt.vnetMocks != nil { - tt.vnetMocks(vnetClient, vnet) + tt.vnetMocks(vnetClient, sdknetwork.VirtualNetworksClientGetResponse{VirtualNetwork: vnet}) } dv := &dynamic{ virtualNetworks: vnetClient, @@ -849,7 +851,7 @@ func TestValidateVnetPermissions(t *testing.T) { dv := &dynamic{ azEnv: &azureclient.PublicCloud, - appID: to.StringPtr("fff51942-b1f9-4119-9453-aaa922259eb7"), + appID: ptr.To("fff51942-b1f9-4119-9453-aaa922259eb7"), authorizerType: AuthorizerClusterServicePrincipal, log: logrus.NewEntry(logrus.StandardLogger()), pdpClient: pdpClient, @@ -930,15 +932,15 @@ func TestValidateRouteTablesPermissions(t *testing.T) { platformIdentities []api.PlatformWorkloadIdentity platformIdentityMap map[string][]string pdpClientMocks func(*mock_azcore.MockTokenCredential, *mock_remotepdp.MockRemotePDPClient, context.CancelFunc) - vnetMocks func(*mock_network.MockVirtualNetworksClient, mgmtnetwork.VirtualNetwork) + vnetMocks func(*mock_armnetwork.MockVirtualNetworksClient, sdknetwork.VirtualNetworksClientGetResponse) wantErr string }{ { name: "fail: failed to get vnet", subnet: Subnet{ID: masterSubnet}, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, errors.New("failed to get vnet")) }, wantErr: "failed to get vnet", @@ -946,10 +948,10 @@ func TestValidateRouteTablesPermissions(t *testing.T) { { name: "fail: master subnet doesn't exist", subnet: Subnet{ID: masterSubnet}, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { - vnet.Subnets = nil + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { + vnet.Properties.Subnets = nil vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, wantErr: "400: InvalidLinkedVNet: : The provided subnet '" + masterSubnet + "' could not be found.", @@ -957,10 +959,10 @@ func TestValidateRouteTablesPermissions(t *testing.T) { { name: "fail: worker subnet ID doesn't exist", subnet: Subnet{ID: workerSubnet}, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { - (*vnet.Subnets)[1].ID = to.StringPtr("not valid") + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { + vnet.Properties.Subnets[1].ID = ptr.To("not valid") vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, wantErr: "400: InvalidLinkedVNet: : The provided subnet '" + workerSubnet + "' could not be found.", @@ -968,20 +970,20 @@ func TestValidateRouteTablesPermissions(t *testing.T) { { name: "pass: no route table to check", subnet: Subnet{ID: masterSubnet}, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { - (*vnet.Subnets)[0].RouteTable = nil - (*vnet.Subnets)[1].RouteTable = nil + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { + vnet.Properties.Subnets[0].Properties.RouteTable = nil + vnet.Properties.Subnets[1].Properties.RouteTable = nil vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, }, { name: "fail: permissions don't exist", subnet: Subnet{ID: workerSubnet}, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, pdpClientMocks: func(tokenCred *mock_azcore.MockTokenCredential, pdpClient *mock_remotepdp.MockRemotePDPClient, cancel context.CancelFunc) { @@ -1002,9 +1004,9 @@ func TestValidateRouteTablesPermissions(t *testing.T) { platformIdentityMap: map[string][]string{ "Dummy": platformIdentity1SubnetActions, }, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, pdpClientMocks: func(tokenCred *mock_azcore.MockTokenCredential, pdpClient *mock_remotepdp.MockRemotePDPClient, cancel context.CancelFunc) { @@ -1021,9 +1023,9 @@ func TestValidateRouteTablesPermissions(t *testing.T) { { name: "fail: CheckAccessV2 doesn't return all the entries", subnet: Subnet{ID: workerSubnet}, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, pdpClientMocks: func(tokenCred *mock_azcore.MockTokenCredential, pdpClient *mock_remotepdp.MockRemotePDPClient, cancel context.CancelFunc) { @@ -1040,9 +1042,9 @@ func TestValidateRouteTablesPermissions(t *testing.T) { { name: "pass", subnet: Subnet{ID: workerSubnet}, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, pdpClientMocks: func(tokenCred *mock_azcore.MockTokenCredential, pdpClient *mock_remotepdp.MockRemotePDPClient, cancel context.CancelFunc) { @@ -1059,9 +1061,9 @@ func TestValidateRouteTablesPermissions(t *testing.T) { platformIdentityMap: map[string][]string{ "Dummy": platformIdentity1SubnetActions, }, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, pdpClientMocks: func(tokenCred *mock_azcore.MockTokenCredential, pdpClient *mock_remotepdp.MockRemotePDPClient, cancel context.CancelFunc) { @@ -1078,9 +1080,9 @@ func TestValidateRouteTablesPermissions(t *testing.T) { platformIdentityMap: map[string][]string{ "Dummy": platformIdentity1SubnetActionsNoIntersect, }, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, pdpClientMocks: func(tokenCred *mock_azcore.MockTokenCredential, pdpClient *mock_remotepdp.MockRemotePDPClient, cancel context.CancelFunc) { @@ -1099,24 +1101,24 @@ func TestValidateRouteTablesPermissions(t *testing.T) { pdpClient := mock_remotepdp.NewMockRemotePDPClient(controller) - vnetClient := mock_network.NewMockVirtualNetworksClient(controller) + vnetClient := mock_armnetwork.NewMockVirtualNetworksClient(controller) - vnet := &mgmtnetwork.VirtualNetwork{ + vnet := &sdknetwork.VirtualNetwork{ ID: &vnetID, - VirtualNetworkPropertiesFormat: &mgmtnetwork.VirtualNetworkPropertiesFormat{ - Subnets: &[]mgmtnetwork.Subnet{ + Properties: &sdknetwork.VirtualNetworkPropertiesFormat{ + Subnets: []*sdknetwork.Subnet{ { ID: &masterSubnet, - SubnetPropertiesFormat: &mgmtnetwork.SubnetPropertiesFormat{ - RouteTable: &mgmtnetwork.RouteTable{ + Properties: &sdknetwork.SubnetPropertiesFormat{ + RouteTable: &sdknetwork.RouteTable{ ID: &masterRtID, }, }, }, { ID: &workerSubnet, - SubnetPropertiesFormat: &mgmtnetwork.SubnetPropertiesFormat{ - RouteTable: &mgmtnetwork.RouteTable{ + Properties: &sdknetwork.SubnetPropertiesFormat{ + RouteTable: &sdknetwork.RouteTable{ ID: &workerRtID, }, }, @@ -1126,7 +1128,7 @@ func TestValidateRouteTablesPermissions(t *testing.T) { } dv := &dynamic{ - appID: to.StringPtr("fff51942-b1f9-4119-9453-aaa922259eb7"), + appID: ptr.To("fff51942-b1f9-4119-9453-aaa922259eb7"), azEnv: &azureclient.PublicCloud, authorizerType: AuthorizerClusterServicePrincipal, log: logrus.NewEntry(logrus.StandardLogger()), @@ -1140,7 +1142,7 @@ func TestValidateRouteTablesPermissions(t *testing.T) { } if tt.vnetMocks != nil { - tt.vnetMocks(vnetClient, *vnet) + tt.vnetMocks(vnetClient, sdknetwork.VirtualNetworksClientGetResponse{VirtualNetwork: *vnet}) } if tt.platformIdentities != nil { @@ -1212,15 +1214,15 @@ func TestValidateNatGatewaysPermissions(t *testing.T) { platformIdentities []api.PlatformWorkloadIdentity platformIdentityMap map[string][]string pdpClientMocks func(*mock_azcore.MockTokenCredential, *mock_remotepdp.MockRemotePDPClient, context.CancelFunc) - vnetMocks func(*mock_network.MockVirtualNetworksClient, mgmtnetwork.VirtualNetwork) + vnetMocks func(*mock_armnetwork.MockVirtualNetworksClient, sdknetwork.VirtualNetworksClientGetResponse) wantErr string }{ { name: "fail: failed to get vnet", subnet: Subnet{ID: masterSubnet}, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, errors.New("failed to get vnet")) }, wantErr: "failed to get vnet", @@ -1228,10 +1230,10 @@ func TestValidateNatGatewaysPermissions(t *testing.T) { { name: "fail: master subnet doesn't exist", subnet: Subnet{ID: masterSubnet}, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { - vnet.Subnets = nil + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { + vnet.Properties.Subnets = nil vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, wantErr: "400: InvalidLinkedVNet: : The provided subnet '" + masterSubnet + "' could not be found.", @@ -1239,10 +1241,10 @@ func TestValidateNatGatewaysPermissions(t *testing.T) { { name: "fail: worker subnet ID doesn't exist", subnet: Subnet{ID: workerSubnet}, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { - (*vnet.Subnets)[1].ID = to.StringPtr("not valid") + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { + vnet.Properties.Subnets[1].ID = ptr.To("not valid") vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, wantErr: "400: InvalidLinkedVNet: : The provided subnet '" + workerSubnet + "' could not be found.", @@ -1250,9 +1252,9 @@ func TestValidateNatGatewaysPermissions(t *testing.T) { { name: "fail: permissions don't exist", subnet: Subnet{ID: workerSubnet}, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, pdpClientMocks: func(tokenCred *mock_azcore.MockTokenCredential, pdpClient *mock_remotepdp.MockRemotePDPClient, cancel context.CancelFunc) { @@ -1274,9 +1276,9 @@ func TestValidateNatGatewaysPermissions(t *testing.T) { platformIdentityMap: map[string][]string{ "Dummy": platformIdentity1SubnetActions, }, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, pdpClientMocks: func(tokenCred *mock_azcore.MockTokenCredential, pdpClient *mock_remotepdp.MockRemotePDPClient, cancel context.CancelFunc) { @@ -1294,9 +1296,9 @@ func TestValidateNatGatewaysPermissions(t *testing.T) { { name: "fail: CheckAccessV2 doesn't return all permissions", subnet: Subnet{ID: workerSubnet}, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, pdpClientMocks: func(tokenCred *mock_azcore.MockTokenCredential, pdpClient *mock_remotepdp.MockRemotePDPClient, cancel context.CancelFunc) { @@ -1314,9 +1316,9 @@ func TestValidateNatGatewaysPermissions(t *testing.T) { { name: "pass", subnet: Subnet{ID: workerSubnet}, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, pdpClientMocks: func(tokenCred *mock_azcore.MockTokenCredential, pdpClient *mock_remotepdp.MockRemotePDPClient, cancel context.CancelFunc) { @@ -1333,9 +1335,9 @@ func TestValidateNatGatewaysPermissions(t *testing.T) { platformIdentityMap: map[string][]string{ "Dummy": platformIdentity1SubnetActions, }, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, pdpClientMocks: func(tokenCred *mock_azcore.MockTokenCredential, pdpClient *mock_remotepdp.MockRemotePDPClient, cancel context.CancelFunc) { @@ -1352,9 +1354,9 @@ func TestValidateNatGatewaysPermissions(t *testing.T) { platformIdentityMap: map[string][]string{ "Dummy": platformIdentity1SubnetActionsNoIntersect, }, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, pdpClientMocks: func(tokenCred *mock_azcore.MockTokenCredential, pdpClient *mock_remotepdp.MockRemotePDPClient, cancel context.CancelFunc) { @@ -1364,11 +1366,11 @@ func TestValidateNatGatewaysPermissions(t *testing.T) { { name: "pass: no nat gateway to check", subnet: Subnet{ID: masterSubnet}, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { - (*vnet.Subnets)[0].NatGateway = nil - (*vnet.Subnets)[1].NatGateway = nil + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { + vnet.Properties.Subnets[0].Properties.NatGateway = nil + vnet.Properties.Subnets[1].Properties.NatGateway = nil vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). Return(vnet, nil) }, }, @@ -1380,28 +1382,28 @@ func TestValidateNatGatewaysPermissions(t *testing.T) { ctx, cancel := context.WithCancel(ctx) defer cancel() - vnetClient := mock_network.NewMockVirtualNetworksClient(controller) + vnetClient := mock_armnetwork.NewMockVirtualNetworksClient(controller) tokenCred := mock_azcore.NewMockTokenCredential(controller) pdpClient := mock_remotepdp.NewMockRemotePDPClient(controller) - vnet := &mgmtnetwork.VirtualNetwork{ + vnet := &sdknetwork.VirtualNetwork{ ID: &vnetID, - VirtualNetworkPropertiesFormat: &mgmtnetwork.VirtualNetworkPropertiesFormat{ - Subnets: &[]mgmtnetwork.Subnet{ + Properties: &sdknetwork.VirtualNetworkPropertiesFormat{ + Subnets: []*sdknetwork.Subnet{ { ID: &masterSubnet, - SubnetPropertiesFormat: &mgmtnetwork.SubnetPropertiesFormat{ - NatGateway: &mgmtnetwork.SubResource{ + Properties: &sdknetwork.SubnetPropertiesFormat{ + NatGateway: &sdknetwork.SubResource{ ID: &masterNgID, }, }, }, { ID: &workerSubnet, - SubnetPropertiesFormat: &mgmtnetwork.SubnetPropertiesFormat{ - NatGateway: &mgmtnetwork.SubResource{ + Properties: &sdknetwork.SubnetPropertiesFormat{ + NatGateway: &sdknetwork.SubResource{ ID: &workerNgID, }, }, @@ -1411,7 +1413,7 @@ func TestValidateNatGatewaysPermissions(t *testing.T) { } dv := &dynamic{ - appID: to.StringPtr("fff51942-b1f9-4119-9453-aaa922259eb7"), + appID: ptr.To("fff51942-b1f9-4119-9453-aaa922259eb7"), azEnv: &azureclient.PublicCloud, authorizerType: AuthorizerClusterServicePrincipal, log: logrus.NewEntry(logrus.StandardLogger()), @@ -1425,7 +1427,7 @@ func TestValidateNatGatewaysPermissions(t *testing.T) { } if tt.vnetMocks != nil { - tt.vnetMocks(vnetClient, *vnet) + tt.vnetMocks(vnetClient, sdknetwork.VirtualNetworksClientGetResponse{VirtualNetwork: *vnet}) } if tt.platformIdentities != nil { @@ -1441,32 +1443,32 @@ func TestValidateNatGatewaysPermissions(t *testing.T) { } func TestCheckPreconfiguredNSG(t *testing.T) { - subnetWithNSG := &mgmtnetwork.Subnet{ - SubnetPropertiesFormat: &mgmtnetwork.SubnetPropertiesFormat{ - NetworkSecurityGroup: &mgmtnetwork.SecurityGroup{ + subnetWithNSG := &sdknetwork.Subnet{ + Properties: &sdknetwork.SubnetPropertiesFormat{ + NetworkSecurityGroup: &sdknetwork.SecurityGroup{ ID: &masterNSGv1, }, }, } - subnetWithoutNSG := &mgmtnetwork.Subnet{ - SubnetPropertiesFormat: &mgmtnetwork.SubnetPropertiesFormat{}, + subnetWithoutNSG := &sdknetwork.Subnet{ + Properties: &sdknetwork.SubnetPropertiesFormat{}, } for _, tt := range []struct { name string - subnetByID map[string]*mgmtnetwork.Subnet + subnetByID map[string]*sdknetwork.Subnet wantErr string }{ { name: "pass: all subnets are attached", - subnetByID: map[string]*mgmtnetwork.Subnet{ + subnetByID: map[string]*sdknetwork.Subnet{ "A": subnetWithNSG, "B": subnetWithNSG, }, }, { name: "fail: no subnets are attached", - subnetByID: map[string]*mgmtnetwork.Subnet{ + subnetByID: map[string]*sdknetwork.Subnet{ "A": subnetWithoutNSG, "B": subnetWithoutNSG, }, @@ -1474,7 +1476,7 @@ func TestCheckPreconfiguredNSG(t *testing.T) { }, { name: "fail: parts of the subnets are attached", - subnetByID: map[string]*mgmtnetwork.Subnet{ + subnetByID: map[string]*sdknetwork.Subnet{ "A": subnetWithNSG, "B": subnetWithoutNSG, "C": subnetWithNSG, @@ -1518,7 +1520,7 @@ func TestValidatePreconfiguredNSGPermissions(t *testing.T) { platformIdentities []api.PlatformWorkloadIdentity platformIdentityMap map[string][]string checkAccessMocks func(context.CancelFunc, *mock_remotepdp.MockRemotePDPClient, *mock_azcore.MockTokenCredential) - vnetMocks func(*mock_network.MockVirtualNetworksClient, mgmtnetwork.VirtualNetwork) + vnetMocks func(*mock_armnetwork.MockVirtualNetworksClient, sdknetwork.VirtualNetworksClientGetResponse) wantErr string }{ { @@ -1532,9 +1534,9 @@ func TestValidatePreconfiguredNSGPermissions(t *testing.T) { modifyOC: func(oc *api.OpenShiftCluster) { oc.Properties.NetworkProfile.PreconfiguredNSG = api.PreconfiguredNSGEnabled }, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). AnyTimes(). Return(vnet, nil) }, @@ -1560,9 +1562,9 @@ func TestValidatePreconfiguredNSGPermissions(t *testing.T) { modifyOC: func(oc *api.OpenShiftCluster) { oc.Properties.NetworkProfile.PreconfiguredNSG = api.PreconfiguredNSGEnabled }, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). AnyTimes(). Return(vnet, nil) }, @@ -1592,9 +1594,9 @@ func TestValidatePreconfiguredNSGPermissions(t *testing.T) { modifyOC: func(oc *api.OpenShiftCluster) { oc.Properties.NetworkProfile.PreconfiguredNSG = api.PreconfiguredNSGEnabled }, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). AnyTimes(). Return(vnet, nil) }, @@ -1613,9 +1615,9 @@ func TestValidatePreconfiguredNSGPermissions(t *testing.T) { modifyOC: func(oc *api.OpenShiftCluster) { oc.Properties.NetworkProfile.PreconfiguredNSG = api.PreconfiguredNSGEnabled }, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). AnyTimes(). Return(vnet, nil) }, @@ -1638,9 +1640,9 @@ func TestValidatePreconfiguredNSGPermissions(t *testing.T) { modifyOC: func(oc *api.OpenShiftCluster) { oc.Properties.NetworkProfile.PreconfiguredNSG = api.PreconfiguredNSGEnabled }, - vnetMocks: func(vnetClient *mock_network.MockVirtualNetworksClient, vnet mgmtnetwork.VirtualNetwork) { + vnetMocks: func(vnetClient *mock_armnetwork.MockVirtualNetworksClient, vnet sdknetwork.VirtualNetworksClientGetResponse) { vnetClient.EXPECT(). - Get(gomock.Any(), resourceGroupName, vnetName, ""). + Get(gomock.Any(), resourceGroupName, vnetName, nil). AnyTimes(). Return(vnet, nil) }, @@ -1673,28 +1675,28 @@ func TestValidatePreconfiguredNSGPermissions(t *testing.T) { }, }, } - vnet := mgmtnetwork.VirtualNetwork{ + vnet := sdknetwork.VirtualNetwork{ ID: &vnetID, - VirtualNetworkPropertiesFormat: &mgmtnetwork.VirtualNetworkPropertiesFormat{ - Subnets: &[]mgmtnetwork.Subnet{ + Properties: &sdknetwork.VirtualNetworkPropertiesFormat{ + Subnets: []*sdknetwork.Subnet{ { ID: &masterSubnet, - SubnetPropertiesFormat: &mgmtnetwork.SubnetPropertiesFormat{ - AddressPrefix: to.StringPtr("10.0.0.0/24"), - NetworkSecurityGroup: &mgmtnetwork.SecurityGroup{ + Properties: &sdknetwork.SubnetPropertiesFormat{ + AddressPrefix: ptr.To("10.0.0.0/24"), + NetworkSecurityGroup: &sdknetwork.SecurityGroup{ ID: &masterNSGv1, }, - ProvisioningState: mgmtnetwork.Succeeded, + ProvisioningState: ptr.To(sdknetwork.ProvisioningStateSucceeded), }, }, { ID: &workerSubnet, - SubnetPropertiesFormat: &mgmtnetwork.SubnetPropertiesFormat{ - AddressPrefix: to.StringPtr("10.0.1.0/24"), - NetworkSecurityGroup: &mgmtnetwork.SecurityGroup{ + Properties: &sdknetwork.SubnetPropertiesFormat{ + AddressPrefix: ptr.To("10.0.1.0/24"), + NetworkSecurityGroup: &sdknetwork.SecurityGroup{ ID: &workerNSGv1, }, - ProvisioningState: mgmtnetwork.Succeeded, + ProvisioningState: ptr.To(sdknetwork.ProvisioningStateSucceeded), }, }, }, @@ -1705,10 +1707,10 @@ func TestValidatePreconfiguredNSGPermissions(t *testing.T) { tt.modifyOC(oc) } - vnetClient := mock_network.NewMockVirtualNetworksClient(controller) + vnetClient := mock_armnetwork.NewMockVirtualNetworksClient(controller) if tt.vnetMocks != nil { - tt.vnetMocks(vnetClient, vnet) + tt.vnetMocks(vnetClient, sdknetwork.VirtualNetworksClientGetResponse{VirtualNetwork: vnet}) } tokenCred := mock_azcore.NewMockTokenCredential(controller) @@ -1720,7 +1722,7 @@ func TestValidatePreconfiguredNSGPermissions(t *testing.T) { dv := &dynamic{ azEnv: &azureclient.PublicCloud, - appID: to.StringPtr("fff51942-b1f9-4119-9453-aaa922259eb7"), + appID: ptr.To("fff51942-b1f9-4119-9453-aaa922259eb7"), authorizerType: AuthorizerClusterServicePrincipal, virtualNetworks: vnetClient, pdpClient: pdpClient, diff --git a/pkg/validate/dynamic/loadbalancerprofile.go b/pkg/validate/dynamic/loadbalancerprofile.go index 11a3e386254..23a36912020 100644 --- a/pkg/validate/dynamic/loadbalancerprofile.go +++ b/pkg/validate/dynamic/loadbalancerprofile.go @@ -50,7 +50,7 @@ func (dv *dynamic) validatePublicIPQuota(ctx context.Context, oc *api.OpenShiftC } } - netUsages, err := dv.spNetworkUsage.List(ctx, oc.Location) + netUsages, err := dv.spNetworkUsage.List(ctx, oc.Location, nil) if err != nil { return err } @@ -75,12 +75,12 @@ func (dv *dynamic) validateOBRuleV4FrontendPorts(ctx context.Context, oc *api.Op loadBalancerName := oc.Properties.InfraID backendAddressPoolName := oc.Properties.InfraID - backendPools, err := dv.loadBalancerBackendAddressPoolsClient.Get(ctx, rgName, loadBalancerName, backendAddressPoolName) + backendPools, err := dv.loadBalancerBackendAddressPoolsClient.Get(ctx, rgName, loadBalancerName, backendAddressPoolName, nil) if err != nil { return err } - totalBackendInstances := len(*backendPools.BackendAddressPoolPropertiesFormat.BackendIPConfigurations) + totalBackendInstances := len(backendPools.Properties.BackendIPConfigurations) // TODO: update once allocatedOutboundPorts is implemented allocatedOutboundPorts := 1024 var desiredNumIPs int diff --git a/pkg/validate/dynamic/loadbalancerprofile_test.go b/pkg/validate/dynamic/loadbalancerprofile_test.go index b344bb829d6..e05ea6ff929 100644 --- a/pkg/validate/dynamic/loadbalancerprofile_test.go +++ b/pkg/validate/dynamic/loadbalancerprofile_test.go @@ -8,13 +8,14 @@ import ( "strconv" "testing" - mgmtnetwork "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2020-08-01/network" + sdknetwork "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v2" "github.com/Azure/go-autorest/autorest/to" "github.com/sirupsen/logrus" "go.uber.org/mock/gomock" + "k8s.io/utils/ptr" "github.com/Azure/ARO-RP/pkg/api" - mock_network "github.com/Azure/ARO-RP/pkg/util/mocks/azureclient/mgmt/network" + mock_armnetwork "github.com/Azure/ARO-RP/pkg/util/mocks/azureclient/azuresdk/armnetwork" utilerror "github.com/Azure/ARO-RP/test/util/error" ) @@ -27,7 +28,7 @@ func TestValidateLoadBalancerProfile(t *testing.T) { for _, tt := range []struct { name string oc *api.OpenShiftCluster - mocks func(spNetworkUsage *mock_network.MockUsageClient, loadBalancerBackendAddressPoolsClient *mock_network.MockLoadBalancerBackendAddressPoolsClient) + mocks func(spNetworkUsage *mock_armnetwork.MockUsagesClient, loadBalancerBackendAddressPoolsClient *mock_armnetwork.MockLoadBalancerBackendAddressPoolsClient) wantErr string }{ { @@ -71,13 +72,13 @@ func TestValidateLoadBalancerProfile(t *testing.T) { }, }, }, - mocks: func(spNetworkUsage *mock_network.MockUsageClient, - loadBalancerBackendAddressPoolsClient *mock_network.MockLoadBalancerBackendAddressPoolsClient) { + mocks: func(spNetworkUsage *mock_armnetwork.MockUsagesClient, + loadBalancerBackendAddressPoolsClient *mock_armnetwork.MockLoadBalancerBackendAddressPoolsClient) { spNetworkUsage.EXPECT(). - List(gomock.Any(), location). - Return([]mgmtnetwork.Usage{ + List(gomock.Any(), location, nil). + Return([]*sdknetwork.Usage{ { - Name: &mgmtnetwork.UsageName{ + Name: &sdknetwork.UsageName{ Value: to.StringPtr("PublicIPAddresses"), }, CurrentValue: to.Int64Ptr(4), @@ -85,10 +86,12 @@ func TestValidateLoadBalancerProfile(t *testing.T) { }, }, nil) loadBalancerBackendAddressPoolsClient.EXPECT(). - Get(gomock.Any(), clusterRGName, infraID, infraID). - Return(mgmtnetwork.BackendAddressPool{ - BackendAddressPoolPropertiesFormat: &mgmtnetwork.BackendAddressPoolPropertiesFormat{ - BackendIPConfigurations: getFakeBackendIPConfigs(6), + Get(gomock.Any(), clusterRGName, infraID, infraID, nil). + Return(sdknetwork.LoadBalancerBackendAddressPoolsClientGetResponse{ + BackendAddressPool: sdknetwork.BackendAddressPool{ + Properties: &sdknetwork.BackendAddressPoolPropertiesFormat{ + BackendIPConfigurations: getFakeBackendIPConfigs(6), + }, }, }, nil) }, @@ -101,8 +104,8 @@ func TestValidateLoadBalancerProfile(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - loadBalancerBackendAddressPoolsClient := mock_network.NewMockLoadBalancerBackendAddressPoolsClient(controller) - networkUsageClient := mock_network.NewMockUsageClient(controller) + loadBalancerBackendAddressPoolsClient := mock_armnetwork.NewMockLoadBalancerBackendAddressPoolsClient(controller) + networkUsageClient := mock_armnetwork.NewMockUsagesClient(controller) if tt.mocks != nil { tt.mocks(networkUsageClient, loadBalancerBackendAddressPoolsClient) @@ -126,7 +129,7 @@ func TestValidatePublicIPQuota(t *testing.T) { for _, tt := range []struct { name string oc *api.OpenShiftCluster - mocks func(spNetworkUsage *mock_network.MockUsageClient) + mocks func(spNetworkUsage *mock_armnetwork.MockUsagesClient) wantErr string }{ { @@ -158,12 +161,12 @@ func TestValidatePublicIPQuota(t *testing.T) { }, }, }, - mocks: func(spNetworkUsage *mock_network.MockUsageClient) { + mocks: func(spNetworkUsage *mock_armnetwork.MockUsagesClient) { spNetworkUsage.EXPECT(). - List(gomock.Any(), location). - Return([]mgmtnetwork.Usage{ + List(gomock.Any(), location, nil). + Return([]*sdknetwork.Usage{ { - Name: &mgmtnetwork.UsageName{ + Name: &sdknetwork.UsageName{ Value: to.StringPtr("PublicIPAddresses"), }, CurrentValue: to.Int64Ptr(4), @@ -201,12 +204,12 @@ func TestValidatePublicIPQuota(t *testing.T) { }, }, }, - mocks: func(spNetworkUsage *mock_network.MockUsageClient) { + mocks: func(spNetworkUsage *mock_armnetwork.MockUsagesClient) { spNetworkUsage.EXPECT(). - List(gomock.Any(), location). - Return([]mgmtnetwork.Usage{ + List(gomock.Any(), location, nil). + Return([]*sdknetwork.Usage{ { - Name: &mgmtnetwork.UsageName{ + Name: &sdknetwork.UsageName{ Value: to.StringPtr("PublicIPAddresses"), }, CurrentValue: to.Int64Ptr(8), @@ -240,12 +243,12 @@ func TestValidatePublicIPQuota(t *testing.T) { }, }, }, - mocks: func(spNetworkUsage *mock_network.MockUsageClient) { + mocks: func(spNetworkUsage *mock_armnetwork.MockUsagesClient) { spNetworkUsage.EXPECT(). - List(gomock.Any(), location). - Return([]mgmtnetwork.Usage{ + List(gomock.Any(), location, nil). + Return([]*sdknetwork.Usage{ { - Name: &mgmtnetwork.UsageName{ + Name: &sdknetwork.UsageName{ Value: to.StringPtr("PublicIPAddresses"), }, CurrentValue: to.Int64Ptr(4), @@ -278,12 +281,12 @@ func TestValidatePublicIPQuota(t *testing.T) { }, }, }, - mocks: func(spNetworkUsage *mock_network.MockUsageClient) { + mocks: func(spNetworkUsage *mock_armnetwork.MockUsagesClient) { spNetworkUsage.EXPECT(). - List(gomock.Any(), location). - Return([]mgmtnetwork.Usage{ + List(gomock.Any(), location, nil). + Return([]*sdknetwork.Usage{ { - Name: &mgmtnetwork.UsageName{ + Name: &sdknetwork.UsageName{ Value: to.StringPtr("PublicIPAddresses"), }, CurrentValue: to.Int64Ptr(8), @@ -301,7 +304,7 @@ func TestValidatePublicIPQuota(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - networkUsageClient := mock_network.NewMockUsageClient(controller) + networkUsageClient := mock_armnetwork.NewMockUsagesClient(controller) if tt.mocks != nil { tt.mocks(networkUsageClient) @@ -327,7 +330,7 @@ func TestValidateOBRuleV4FrontendPorts(t *testing.T) { for _, tt := range []struct { name string oc *api.OpenShiftCluster - mocks func(loadBalancerBackendAddressPoolsClient *mock_network.MockLoadBalancerBackendAddressPoolsClient) + mocks func(loadBalancerBackendAddressPoolsClient *mock_armnetwork.MockLoadBalancerBackendAddressPoolsClient) wantErr string }{ { @@ -358,12 +361,14 @@ func TestValidateOBRuleV4FrontendPorts(t *testing.T) { }, }, mocks: func( - loadBalancerBackendAddressPoolsClient *mock_network.MockLoadBalancerBackendAddressPoolsClient) { + loadBalancerBackendAddressPoolsClient *mock_armnetwork.MockLoadBalancerBackendAddressPoolsClient) { loadBalancerBackendAddressPoolsClient.EXPECT(). - Get(gomock.Any(), clusterRGName, infraID, infraID). - Return(mgmtnetwork.BackendAddressPool{ - BackendAddressPoolPropertiesFormat: &mgmtnetwork.BackendAddressPoolPropertiesFormat{ - BackendIPConfigurations: getFakeBackendIPConfigs(62), + Get(gomock.Any(), clusterRGName, infraID, infraID, nil). + Return(sdknetwork.LoadBalancerBackendAddressPoolsClientGetResponse{ + BackendAddressPool: sdknetwork.BackendAddressPool{ + Properties: &sdknetwork.BackendAddressPoolPropertiesFormat{ + BackendIPConfigurations: getFakeBackendIPConfigs(62), + }, }, }, nil) }, @@ -397,12 +402,14 @@ func TestValidateOBRuleV4FrontendPorts(t *testing.T) { }, wantErr: "400: InvalidParameter: properties.networkProfile.loadBalancerProfile: Insufficient frontend ports to support the backend instance count. Total frontend ports: 63992, Required frontend ports: 64512, Total backend instances: 63", mocks: func( - loadBalancerBackendAddressPoolsClient *mock_network.MockLoadBalancerBackendAddressPoolsClient) { + loadBalancerBackendAddressPoolsClient *mock_armnetwork.MockLoadBalancerBackendAddressPoolsClient) { loadBalancerBackendAddressPoolsClient.EXPECT(). - Get(gomock.Any(), clusterRGName, infraID, infraID). - Return(mgmtnetwork.BackendAddressPool{ - BackendAddressPoolPropertiesFormat: &mgmtnetwork.BackendAddressPoolPropertiesFormat{ - BackendIPConfigurations: getFakeBackendIPConfigs(63), + Get(gomock.Any(), clusterRGName, infraID, infraID, nil). + Return(sdknetwork.LoadBalancerBackendAddressPoolsClientGetResponse{ + BackendAddressPool: sdknetwork.BackendAddressPool{ + Properties: &sdknetwork.BackendAddressPoolPropertiesFormat{ + BackendIPConfigurations: getFakeBackendIPConfigs(63), + }, }, }, nil) }, @@ -415,7 +422,7 @@ func TestValidateOBRuleV4FrontendPorts(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - loadBalancerBackendAddressPoolsClient := mock_network.NewMockLoadBalancerBackendAddressPoolsClient(controller) + loadBalancerBackendAddressPoolsClient := mock_armnetwork.NewMockLoadBalancerBackendAddressPoolsClient(controller) if tt.mocks != nil { tt.mocks(loadBalancerBackendAddressPoolsClient) @@ -432,11 +439,11 @@ func TestValidateOBRuleV4FrontendPorts(t *testing.T) { } } -func getFakeBackendIPConfigs(ipConfigCount int) *[]mgmtnetwork.InterfaceIPConfiguration { - ipConfigs := []mgmtnetwork.InterfaceIPConfiguration{} +func getFakeBackendIPConfigs(ipConfigCount int) []*sdknetwork.InterfaceIPConfiguration { + var ipConfigs []*sdknetwork.InterfaceIPConfiguration for i := 0; i < ipConfigCount; i++ { ipConfigName := "ip-" + strconv.Itoa(i) - ipConfigs = append(ipConfigs, mgmtnetwork.InterfaceIPConfiguration{Name: to.StringPtr(ipConfigName)}) + ipConfigs = append(ipConfigs, &sdknetwork.InterfaceIPConfiguration{Name: ptr.To(ipConfigName)}) } - return &ipConfigs + return ipConfigs } diff --git a/pkg/validate/openshiftcluster_validatedynamic.go b/pkg/validate/openshiftcluster_validatedynamic.go index fa9156110d5..081237834e3 100644 --- a/pkg/validate/openshiftcluster_validatedynamic.go +++ b/pkg/validate/openshiftcluster_validatedynamic.go @@ -149,7 +149,7 @@ func (dv *openShiftClusterDynamicValidator) Dynamic(ctx context.Context) error { return err } - spDynamic = dynamic.NewValidator( + spDynamic, err = dynamic.NewValidator( dv.log, dv.env, dv.env.Environment(), @@ -160,13 +160,16 @@ func (dv *openShiftClusterDynamicValidator) Dynamic(ctx context.Context) error { spClientCred, pdpClient, ) + if err != nil { + return err + } err = spDynamic.ValidateServicePrincipal(ctx, spClientCred) if err != nil { return err } } else { // PlatformWorkloadIdentity and ClusterMSIIdentity Validation - spDynamic = dynamic.NewValidator( + spDynamic, err = dynamic.NewValidator( dv.log, dv.env, dv.env.Environment(), @@ -177,6 +180,9 @@ func (dv *openShiftClusterDynamicValidator) Dynamic(ctx context.Context) error { nil, pdpClient, ) + if err != nil { + return err + } err = spDynamic.ValidatePlatformWorkloadIdentityProfile(ctx, dv.oc, dv.platformWorkloadIdentityRolesByVersion.GetPlatformWorkloadIdentityRolesByRoleName(), dv.roleDefinitions) if err != nil { return err @@ -220,7 +226,7 @@ func (dv *openShiftClusterDynamicValidator) Dynamic(ctx context.Context) error { } // FP validation - fpDynamic := dynamic.NewValidator( + fpDynamic, err := dynamic.NewValidator( dv.log, dv.env, dv.env.Environment(), @@ -231,6 +237,9 @@ func (dv *openShiftClusterDynamicValidator) Dynamic(ctx context.Context) error { fpClientCred, pdpClient, ) + if err != nil { + return err + } err = fpDynamic.ValidateVnet( ctx,