Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 31 additions & 13 deletions tests/test_commands/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,8 @@ def workspace(vcr_instance, test_data):
workspace_name = f"{display_name}.Workspace"
workspace_path = f"/{workspace_name}"

mkdir(workspace_path, params=[f"capacityName={test_data.capacity.name}"])
mkdir(workspace_path, params=[
f"capacityName={test_data.capacity.name}"])
yield EntityMetadata(display_name, workspace_name, workspace_path)
rm(workspace_path)

Expand All @@ -292,7 +293,8 @@ def _create_item(
generated_name = custom_name
else:
# Use the test's specific recording file
generated_name = generate_random_string(vcr_instance, cassette_name)
generated_name = generate_random_string(
vcr_instance, cassette_name)

item_name = f"{generated_name}.{type}"
item_path = cli_path_join(path, item_name)
Expand All @@ -318,7 +320,8 @@ def _create_item(
@pytest.fixture
def folder_factory(vcr_instance, cassette_name, workspace):
# Keep track of all folders created during this test
current_config = state_config.get_config(fab_constant.FAB_FOLDER_LISTING_ENABLED)
current_config = state_config.get_config(
fab_constant.FAB_FOLDER_LISTING_ENABLED)
state_config.set_config(fab_constant.FAB_FOLDER_LISTING_ENABLED, "true")
created_folders = []

Expand Down Expand Up @@ -348,7 +351,8 @@ def _create_folder(
for metadata in reversed(created_folders):
rm(metadata.full_path)

state_config.set_config(fab_constant.FAB_FOLDER_LISTING_ENABLED, current_config)
state_config.set_config(
fab_constant.FAB_FOLDER_LISTING_ENABLED, current_config)


@pytest.fixture
Expand All @@ -374,7 +378,8 @@ def _create_virtual_item(
"""
generated_name = generate_random_string(vcr_instance, cassette_name)
virtual_item_name = f"{generated_name}.{str(VICMap[type])}"
virtual_item_path = cli_path_join(workspace_path, str(type), virtual_item_name)
virtual_item_path = cli_path_join(
workspace_path, str(type), virtual_item_name)

match type:

Expand Down Expand Up @@ -427,7 +432,8 @@ def _create_virtual_item(
mkdir(virtual_item_path, params)

# Build the metadata for the created resource
metadata = EntityMetadata(generated_name, virtual_item_name, virtual_item_path)
metadata = EntityMetadata(
generated_name, virtual_item_name, virtual_item_path)
if should_clean:
created_virtual_items.append(metadata)
return metadata
Expand Down Expand Up @@ -457,10 +463,12 @@ def _create_workspace(special_character=None):
workspace_name = f"{generated_name}.Workspace"
workspace_path = f"/{workspace_name}"

mkdir(workspace_path, params=[f"capacityName={test_data.capacity.name}"])
mkdir(workspace_path, params=[
f"capacityName={test_data.capacity.name}"])

# Build the metadata for the created resource
metadata = EntityMetadata(generated_name, workspace_name, workspace_path)
metadata = EntityMetadata(
generated_name, workspace_name, workspace_path)
created_workspaces.append(metadata)
return metadata

Expand All @@ -473,7 +481,10 @@ def _create_workspace(special_character=None):

@pytest.fixture
def virtual_workspace_item_factory(
vcr_instance, cassette_name, test_data: StaticTestData
vcr_instance,
cassette_name,
test_data: StaticTestData,
vcr_mode,
):
# Keep track of all workspaces created during this test
created_virtual_workspace_items = []
Expand Down Expand Up @@ -506,13 +517,16 @@ def _create_virtual_workspace_item(type: VirtualWorkspaceType):
metadata = EntityMetadata(
generated_name, virtual_workspace_name, virtual_workspace_item_path
)
metadata.type = type
created_virtual_workspace_items.append(metadata)
return metadata

yield _create_virtual_workspace_item

# Teardown: remove everything we created during the test
for metadata in created_virtual_workspace_items:
if vcr_mode == "none" and metadata.type == VirtualWorkspaceType.CAPACITY:
continue
rm(metadata.full_path)


Expand Down Expand Up @@ -565,7 +579,8 @@ def delete_cassette_if_record_mode_all(vcr_instance, cassette_name):
:param cassette_name: The name of the cassette file.
"""
if vcr_instance.record_mode == "all":
cassette_path = os.path.join(vcr_instance.cassette_library_dir, cassette_name)
cassette_path = os.path.join(
vcr_instance.cassette_library_dir, cassette_name)
if os.path.exists(cassette_path):
os.remove(cassette_path)

Expand Down Expand Up @@ -673,7 +688,8 @@ def setup_config_values_for_capacity(test_data: StaticTestData):
fab_default_az_location = state_config.get_config(
fab_constant.FAB_DEFAULT_AZ_LOCATION
)
fab_default_az_admin = state_config.get_config(fab_constant.FAB_DEFAULT_AZ_ADMIN)
fab_default_az_admin = state_config.get_config(
fab_constant.FAB_DEFAULT_AZ_ADMIN)

# Setup new values
state_config.set_config(
Expand All @@ -687,7 +703,8 @@ def setup_config_values_for_capacity(test_data: StaticTestData):
state_config.set_config(
fab_constant.FAB_DEFAULT_AZ_LOCATION, test_data.azure_location
)
state_config.set_config(fab_constant.FAB_DEFAULT_AZ_ADMIN, test_data.admin.upn)
state_config.set_config(
fab_constant.FAB_DEFAULT_AZ_ADMIN, test_data.admin.upn)

yield

Expand All @@ -701,7 +718,8 @@ def setup_config_values_for_capacity(test_data: StaticTestData):
state_config.set_config(
fab_constant.FAB_DEFAULT_AZ_LOCATION, fab_default_az_location
)
state_config.set_config(fab_constant.FAB_DEFAULT_AZ_ADMIN, fab_default_az_admin)
state_config.set_config(
fab_constant.FAB_DEFAULT_AZ_ADMIN, fab_default_az_admin)


# endregion
23 changes: 20 additions & 3 deletions tests/test_commands/data/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from typing import Any


class User:
def __init__(self, user_data: dict[str, str]):
# Expecting keys: "id", "upn"
Expand Down Expand Up @@ -56,10 +59,13 @@ def name(self) -> str:


class EntityMetadata:
def __init__(self, display_name: str, name: str, full_path: str):
def __init__(
self, display_name: str, name: str, full_path: str, type: Any = None
):
self._display_name = display_name
self._name = name
self._full_path = full_path
self._type = type

@property
def display_name(self) -> str:
Expand All @@ -73,11 +79,20 @@ def name(self) -> str:
def full_path(self) -> str:
return self._full_path

@property
def type(self) -> Any:
return self._type

# This setter is required for the mv command
@full_path.setter
def full_path(self, new_path):
self._full_path = new_path

# This setter is required for cleanup during test teardown
@type.setter
def type(self, value: Any):
self._type = value


class SQLServer:
def __init__(self, sql_server_data: dict[str, str]):
Expand Down Expand Up @@ -123,16 +138,18 @@ def username(self) -> str:
def password(self) -> str:
return self._password


class OnPremisesGatewayDetails:
def __init__(self, gateway_data: dict[str, str]):
# Expecting keys: "id", "encrypted_credentials"
self._id = gateway_data.get("id") or ""
self._encrypted_credentials = gateway_data.get("encrypted_credentials") or ""
self._encrypted_credentials = gateway_data.get(
"encrypted_credentials") or ""

@property
def id(self) -> str:
return self._id

@property
def encrypted_credentials(self) -> str:
return self._encrypted_credentials
return self._encrypted_credentials
Loading
Loading