Skip to content

Commit

Permalink
{Spring} Extract get_test_cmd (#7494)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiec-msft committed Apr 11, 2024
1 parent 0307898 commit 26020a3
Show file tree
Hide file tree
Showing 14 changed files with 134 additions and 213 deletions.
5 changes: 5 additions & 0 deletions src/spring/azext_spring/tests/latest/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# -----------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# -----------------------------------------------------------------------------
24 changes: 24 additions & 0 deletions src/spring/azext_spring/tests/latest/common/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# -----------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# -----------------------------------------------------------------------------

from azure.cli.core import AzCommandsLoader
from azure.cli.core.commands import AzCliCommand
from azure.cli.core.mock import DummyCli

try:
import unittest.mock as mock
except ImportError:
from unittest import mock


def get_test_cmd():
cli_ctx = DummyCli()
cli_ctx.data['subscription_id'] = '00000000-0000-0000-0000-000000000000'
loader = AzCommandsLoader(cli_ctx, resource_type='Microsoft.AppPlatform')
cmd = AzCliCommand(loader, 'test', None)
cmd.command_kwargs = {'resource_type': 'Microsoft.AppPlatform'}
cmd.cli_ctx = cli_ctx
return cmd
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from azure.cli.core.mock import DummyCli
from azure.cli.core.commands import AzCliCommand

from ..common.test_utils import get_test_cmd
from ....managed_components.managed_component import get_component
from ....managed_components.validators_managed_component import (validate_component_logs,
validate_component_list,
Expand Down Expand Up @@ -56,26 +57,16 @@
]


def _get_test_cmd():
cli_ctx = DummyCli()
cli_ctx.data['subscription_id'] = '00000000-0000-0000-0000-000000000000'
loader = AzCommandsLoader(cli_ctx, resource_type='Microsoft.AppPlatform')
cmd = AzCliCommand(loader, 'test', None)
cmd.command_kwargs = {'resource_type': 'Microsoft.AppPlatform'}
cmd.cli_ctx = cli_ctx
return cmd


class TestValidateComponentList(unittest.TestCase):
@mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True)
def test_tier(self, is_enterprise_tier_mock):
is_enterprise_tier_mock.return_value = False

with self.assertRaises(NotSupportedPricingTierError):
validate_component_list(_get_test_cmd(), Namespace(resource_group="group", service="service"))
validate_component_list(get_test_cmd(), Namespace(resource_group="group", service="service"))

is_enterprise_tier_mock.return_value = True
validate_component_list(_get_test_cmd(), Namespace(resource_group="group", service="service"))
validate_component_list(get_test_cmd(), Namespace(resource_group="group", service="service"))


class TestValidateComponentInstanceList(unittest.TestCase):
Expand All @@ -85,14 +76,14 @@ def test_component_name(self, is_enterprise_tier_mock):

for c in valid_component_names:
ns = Namespace(resource_group="group", service="service", component=c)
validate_instance_list(_get_test_cmd(), ns)
validate_instance_list(get_test_cmd(), ns)
component_obj = get_component(ns.component)
self.assertIsNotNone(component_obj)

for c in invalid_component_names:
with self.assertRaises(InvalidArgumentValueError) as context:
ns = Namespace(resource_group="group", service="service", component=c)
validate_instance_list(_get_test_cmd(), ns)
validate_instance_list(get_test_cmd(), ns)

self.assertTrue("is not supported" in str(context.exception))
self.assertTrue("Supported components are:" in str(context.exception))
Expand All @@ -102,11 +93,11 @@ def test_tier(self, is_enterprise_tier_mock):
is_enterprise_tier_mock.return_value = True

ns = Namespace(resource_group="group", service="service", component="application-configuration-service")
validate_instance_list(_get_test_cmd(), ns)
validate_instance_list(get_test_cmd(), ns)

is_enterprise_tier_mock.return_value = False
with self.assertRaises(NotSupportedPricingTierError):
validate_instance_list(_get_test_cmd(), ns)
validate_instance_list(get_test_cmd(), ns)


class TestValidateComponentLogs(unittest.TestCase):
Expand All @@ -122,7 +113,7 @@ def test_mutual_exclusive_param(self, is_enterprise_tier_mock):
)

with self.assertRaises(InvalidArgumentValueError) as context:
validate_component_logs(_get_test_cmd(), ns)
validate_component_logs(get_test_cmd(), ns)

self.assertEquals("--all-instances cannot be set together with --instance/-i.", str(context.exception))

Expand All @@ -139,7 +130,7 @@ def test_required_param_missing(self, is_enterprise_tier_mock):
)

with self.assertRaises(InvalidArgumentValueError) as context:
validate_component_logs(_get_test_cmd(), ns)
validate_component_logs(get_test_cmd(), ns)

self.assertEquals("When --name/-n is not set, --instance/-i is required.", str(context.exception))

Expand All @@ -160,7 +151,7 @@ def test_only_instance_name(self, is_enterprise_tier_mock):
)

with self.assertLogs('cli.azext_spring.managed_components.validators_managed_component', 'WARNING') as cm:
validate_component_logs(_get_test_cmd(), ns)
validate_component_logs(get_test_cmd(), ns)
self.assertEquals(cm.output, ['WARNING:cli.azext_spring.managed_components.validators_managed_component:--instance/-i is specified without --name/-n, will try best effort get logs by instance.'])

@mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True)
Expand All @@ -179,7 +170,7 @@ def test_valid_component_name(self, is_enterprise_tier_mock):
since=None,
max_log_requests=5
)
validate_component_logs(_get_test_cmd(), ns)
validate_component_logs(get_test_cmd(), ns)

component_obj = get_component(ns.name)
self.assertIsNotNone(component_obj)
Expand Down Expand Up @@ -207,7 +198,7 @@ def test_valid_log_lines(self, is_enterprise_tier_mock):
since=None,
max_log_requests=5
)
validate_component_logs(_get_test_cmd(), ns)
validate_component_logs(get_test_cmd(), ns)
self.assertEquals(lines, ns.lines)

@mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True)
Expand All @@ -226,7 +217,7 @@ def test_log_lines_too_small(self, is_enterprise_tier_mock):
since=None
)
with self.assertRaises(InvalidArgumentValueError) as context:
validate_component_logs(_get_test_cmd(), ns)
validate_component_logs(get_test_cmd(), ns)
self.assertEquals('--lines must be in the range [1,10000]', str(context.exception))

@mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True)
Expand All @@ -246,7 +237,7 @@ def test_log_lines_too_big(self, is_enterprise_tier_mock):
max_log_requests=5
)
with self.assertLogs('cli.azext_spring.log_stream.log_stream_validators', 'ERROR') as cm:
validate_component_logs(_get_test_cmd(), ns)
validate_component_logs(get_test_cmd(), ns)
expect_error_msgs = ['ERROR:cli.azext_spring.log_stream.log_stream_validators:'
'--lines can not be more than 10000, using 10000 instead']
self.assertEquals(expect_error_msgs, cm.output)
Expand Down Expand Up @@ -274,7 +265,7 @@ def test_valid_log_since(self, is_enterprise_tier_mock):
since=since,
max_log_requests=5
)
validate_component_logs(_get_test_cmd(), ns)
validate_component_logs(get_test_cmd(), ns)
last = since[-1:]
since_in_seconds = int(since[:-1]) if last in ("hms") else int(since)
if last == 'h':
Expand Down Expand Up @@ -302,7 +293,7 @@ def test_invalid_log_since(self, is_enterprise_tier_mock):
since=since
)
with self.assertRaises(InvalidArgumentValueError) as context:
validate_component_logs(_get_test_cmd(), ns)
validate_component_logs(get_test_cmd(), ns)
self.assertEquals("--since contains invalid characters", str(context.exception))

@mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True)
Expand All @@ -324,7 +315,7 @@ def test_log_since_too_big(self, is_enterprise_tier_mock):
since=since
)
with self.assertRaises(InvalidArgumentValueError) as context:
validate_component_logs(_get_test_cmd(), ns)
validate_component_logs(get_test_cmd(), ns)
self.assertEquals("--since can not be more than 1h", str(context.exception))

@mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True)
Expand All @@ -346,7 +337,7 @@ def test_valid_log_limit(self, is_enterprise_tier_mock):
since='1h',
max_log_requests=5
)
validate_component_logs(_get_test_cmd(), ns)
validate_component_logs(get_test_cmd(), ns)
self.assertEquals(limit * 1024, ns.limit)

@mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True)
Expand All @@ -368,7 +359,7 @@ def test_negative_log_limit(self, is_enterprise_tier_mock):
since='1h'
)
with self.assertRaises(InvalidArgumentValueError) as context:
validate_component_logs(_get_test_cmd(), ns)
validate_component_logs(get_test_cmd(), ns)
self.assertEquals('--limit must be in the range [1,2048]', str(context.exception))

@mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True)
Expand All @@ -391,7 +382,7 @@ def test_log_limit_too_big(self, is_enterprise_tier_mock):
max_log_requests=5,
)
with self.assertLogs('cli.azext_spring.log_stream.log_stream_validators', 'ERROR') as cm:
validate_component_logs(_get_test_cmd(), ns)
validate_component_logs(get_test_cmd(), ns)
error_msgs = ['ERROR:cli.azext_spring.log_stream.log_stream_validators:'
'--limit can not be more than 2048, using 2048 instead']
self.assertEquals(error_msgs, cm.output)
Expand All @@ -413,7 +404,7 @@ def test_tier(self, is_enterprise_tier_mock):
max_log_requests=5
)

validate_component_logs(_get_test_cmd(), ns)
validate_component_logs(get_test_cmd(), ns)

@mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True)
def test_invalid_tier(self, is_enterprise_tier_mock):
Expand All @@ -432,5 +423,5 @@ def test_invalid_tier(self, is_enterprise_tier_mock):
)

with self.assertRaises(NotSupportedPricingTierError) as context:
validate_component_logs(_get_test_cmd(), ns)
validate_component_logs(get_test_cmd(), ns)
self.assertEquals("Only enterprise tier service instance is supported in this command.", str(context.exception))
19 changes: 5 additions & 14 deletions src/spring/azext_spring/tests/latest/test_asa_api_portal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from azure.cli.testsdk.preparers import (
RoleBasedServicePrincipalPreparer
)
from .common.test_utils import get_test_cmd
from .custom_preparers import SpringPreparer, SpringResourceGroupPreparer, SpringSubResourceWrapper
from .custom_dev_setting_constant import SpringTestEnvironmentEnum
from ...vendored_sdks.appplatform.v2024_01_01_preview import models
Expand Down Expand Up @@ -96,16 +97,6 @@ def test_api_portal(self, resource_group, spring, sp_name, sp_password):
])


def _get_test_cmd():
cli_ctx = DummyCli()
cli_ctx.data['subscription_id'] = '00000000-0000-0000-0000-000000000000'
loader = AzCommandsLoader(cli_ctx, resource_type='Microsoft.AppPlatform')
cmd = AzCliCommand(loader, 'test', None)
cmd.command_kwargs = {'resource_type': 'Microsoft.AppPlatform'}
cmd.cli_ctx = cli_ctx
return cmd


def _get_basic_mock_client(*_):
return mock.MagicMock()

Expand All @@ -123,7 +114,7 @@ def _execute(self, method, cmd, client, *kwargs):

def test_custom_domain(self):
client = _get_basic_mock_client()
api_portal_custom_domain_update(_get_test_cmd(),
api_portal_custom_domain_update(get_test_cmd(),
client,
'rg',
'asa',
Expand All @@ -145,7 +136,7 @@ def _get_cert(*_, **__):

client = _get_basic_mock_client()
client.certificates.get = _get_cert
api_portal_custom_domain_update(_get_test_cmd(),
api_portal_custom_domain_update(get_test_cmd(),
client,
'rg',
'asa',
Expand All @@ -166,7 +157,7 @@ def _get_cert(*_, **__):
client.certificates.get = _get_cert
self.assertRaises(RuntimeError,
api_portal_custom_domain_update,
_get_test_cmd(),
get_test_cmd(),
client,
'rg',
'asa',
Expand All @@ -175,7 +166,7 @@ def _get_cert(*_, **__):

def test_custom_domain_unbind(self):
client = _get_basic_mock_client()
api_portal_custom_domain_unbind(_get_test_cmd(),
api_portal_custom_domain_unbind(get_test_cmd(),
client,
'rg',
'asa',
Expand Down
Loading

0 comments on commit 26020a3

Please sign in to comment.