Skip to content

Commit

Permalink
Add tests and fux style
Browse files Browse the repository at this point in the history
  • Loading branch information
vthiebaut10 committed Jun 17, 2024
1 parent 872a28d commit 89e0b6c
Show file tree
Hide file tree
Showing 11 changed files with 157 additions and 63 deletions.
88 changes: 53 additions & 35 deletions src/ssh/azext_ssh/connectivity_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
# --------------------------------------------------------------------------------------------

import time
import stat
import os
import urllib.request
import json
import base64
import oras.client
Expand Down Expand Up @@ -228,14 +226,17 @@ def install_client_side_proxy(arc_proxy_folder):
install_dir = _get_proxy_install_dir(arc_proxy_folder)
proxy_name = _get_proxy_filename(client_operating_system, client_architecture)
install_location = os.path.join(install_dir, proxy_name)

# Only download new proxy if it doesn't exist already
if not os.path.isfile(install_location):
if not os.path.isdir(install_dir):
file_utils.create_directory(install_dir, f"Failed to create client proxy directory '{install_dir}'. ")
file_utils.create_directory(install_dir, f"Failed to create client proxy directory '{install_dir}'.")
# if directory exists, delete any older versions of the proxy
else:
older_version_location = _get_older_version_proxy_path(install_dir, client_operating_system, client_architecture)
older_version_location = _get_older_version_proxy_path(
install_dir,
client_operating_system,
client_architecture)
older_version_files = glob(older_version_location)
for f in older_version_files:
file_utils.delete_file(f, f"failed to delete older version file {f}", warning=True)
Expand All @@ -247,20 +248,20 @@ def install_client_side_proxy(arc_proxy_folder):


def _download_proxy_from_MCR(dest_dir, proxy_name, operating_system, architecture):
mar_target = f"{consts.CLIENT_PROXY_MCR_TARGET}/{operating_system.lower()}/{architecture}/ssh-proxy"
logger.debug(f"Downloading Arc Connectivity Proxy from {mar_target} in Microsoft Artifact Regristy.")
mar_target = f"{consts.CLIENT_PROXY_MCR_TARGET}/{operating_system.lower()}/{architecture}/ssh-proxy"
logger.debug("Downloading Arc Connectivity Proxy from %s in Microsoft Artifact Regristy.", mar_target)

client = oras.client.OrasClient()

t0 = time.time()

try:
response = client.pull(target=f"{mar_target}:{consts.CLIENT_PROXY_VERSION}", outdir=dest_dir)
except Exception as e:
raise azclierror.CLIInternalError("Failed to download Arc Connectivity proxy. Please try again.")

raise azclierror.CLIInternalError(
f"Failed to download Arc Connectivity proxy with error {str(e)}. Please try again.")

time_elapsed = time.time() - t0

proxy_data = {
'Context.Default.AzureCLI.SSHProxyDownloadTime': time_elapsed,
'Context.Default.AzureCLI.SSHProxyVersion': consts.CLIENT_PROXY_VERSION
Expand All @@ -274,20 +275,22 @@ def _download_proxy_from_MCR(dest_dir, proxy_name, operating_system, architectur

def _get_proxy_package_path_from_oras_response(pull_response):
if not isinstance(pull_response, list):
raise azclierror.CLIInternalError("Attempt to download Arc Connectivity Proxy returned unnexpected result. Please try again.")

raise azclierror.CLIInternalError(
"Attempt to download Arc Connectivity Proxy returned unnexpected result. Please try again.")

if len(pull_response) != 1:
for r in pull_response:
file_utils.delete_file(r, f"Failed to delete {r}. Please delete it manually.", True)
raise azclierror.CLIInternalError("Attempt to download Arc Connectivity Proxy returned unnexpected result. Please try again.")

raise azclierror.CLIInternalError(
"Attempt to download Arc Connectivity Proxy returned unnexpected result. Please try again.")

proxy_package_path = pull_response[0]

if not os.path.isfile(proxy_package_path):
raise azclierror.CLIInternalError("Unable to download Arc Connectivity Proxy. Please try again.")
logger.debug(f"Proxy package downloaded to {proxy_package_path}")

logger.debug("Proxy package downloaded to %s", proxy_package_path)

return proxy_package_path


Expand All @@ -297,23 +300,31 @@ def _extract_proxy_tar_files(proxy_package_path, install_dir, proxy_name):
for member in tar.getmembers():
if member.isfile():
filenames = member.name.split('/')

if len(filenames) != 2:
tar.close()
file_utils.delete_file(proxy_package_path, f"Failed to delete {proxy_package_path}. Please delete it manually.", True)
raise azclierror.CLIInternalError("Attempt to download Arc Connectivity Proxy returned unnexpected result. Please try again.")

file_utils.delete_file(
proxy_package_path,
f"Failed to delete {proxy_package_path}. Please delete it manually.",
True)
raise azclierror.CLIInternalError(
"Attempt to download Arc Connectivity Proxy returned unnexpected result. Please try again.")

member.name = filenames[1]

if member.name.startswith('sshproxy'):
member.name = proxy_name
elif member.name.lower() not in ['license.txt', 'thirdpartynotice.txt']:
tar.close()
file_utils.delete_file(proxy_package_path, f"Failed to delete {proxy_package_path}. Please delete it manually.", True)
raise azclierror.CLIInternalError("Attempt to download Arc Connectivity Proxy returned unnexpected result. Please try again.")
file_utils.delete_file(
proxy_package_path,
f"Failed to delete {proxy_package_path}. Please delete it manually.",
True)
raise azclierror.CLIInternalError(
"Attempt to download Arc Connectivity Proxy returned unnexpected result. Please try again.")

members.append(member)

tar.extractall(members=members, path=install_dir)


Expand All @@ -322,19 +333,28 @@ def _check_proxy_installation(install_dir, proxy_name):
if os.path.isfile(proxy_filepath):
print_styled_text((Style.SUCCESS, f"Successfuly installed SSH Connectivity Proxy file {proxy_filepath}"))
else:
raise azclierror.CLIInternalError(f"Failed to install required SSH Arc Connectivity Proxy. Couldn't find expected file {proxy_filepath}. Please try again.")

raise azclierror.CLIInternalError(
"Failed to install required SSH Arc Connectivity Proxy. "
f"Couldn't find expected file {proxy_filepath}. Please try again.")

license_files = ["License.txt", "ThirdPartyNotice.txt"]
for file in license_files:
file_location = os.path.join(install_dir, file)
if os.path.isfile(file_location):
print_styled_text((Style.SUCCESS, f"Successfuly installed SSH Connectivity Proxy License file {file_location}"))
print_styled_text(
(Style.SUCCESS,
f"Successfuly installed SSH Connectivity Proxy License file {file_location}"))
else:
logger.warning(f"Failed to download Arc Connectivity Proxy license file {file}. Clouldn't find expected file {file_location}. This won't affect your connection.")

logger.warning(
"Failed to download Arc Connectivity Proxy license file %s. Clouldn't find expected file %s. "
"This won't affect your connection.", file, file_location)


def _get_proxy_filename(operating_system, architecture):
proxy_filename = f"sshProxy_{operating_system.lower()}_{architecture}_{consts.CLIENT_PROXY_VERSION.replace('.', '_')}"
if operating_system.lower() == 'darwin' and architecture == '386':
raise azclierror.BadRequestError("Unsupported Darwin OS with 386 architecture.")
proxy_filename = \
f"sshProxy_{operating_system.lower()}_{architecture}_{consts.CLIENT_PROXY_VERSION.replace('.', '_')}"
if operating_system.lower() == 'windows':
proxy_filename += '.exe'
return proxy_filename
Expand Down Expand Up @@ -368,7 +388,7 @@ def _get_client_architeture():
raise azclierror.BadRequestError("Couldn't identify the platform architecture.")
else:
raise azclierror.BadRequestError(f"Unsuported architecture: {machine} is not currently supported")

return architecture


Expand All @@ -381,5 +401,3 @@ def _get_client_operating_system():
if operating_system.lower() not in ('linux', 'darwin', 'windows'):
raise azclierror.BadRequestError(f"Unsuported OS: {operating_system} platform is not currently supported")
return operating_system


5 changes: 1 addition & 4 deletions src/ssh/azext_ssh/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_
ssh_client_folder=None, delete_credentials=False, resource_type=None, ssh_proxy_folder=None,
winrdp=False, yes_without_prompt=False, ssh_args=None):

connectivity_utils.install_client_side_proxy(ssh_proxy_folder)
return

# delete_credentials can only be used by Azure Portal to provide one-click experience on CloudShell.
if delete_credentials and os.environ.get("AZUREPS_HOST_ENVIRONMENT") != "cloud-shell/1.0":
raise azclierror.ArgumentUsageError("Can't use --delete-private-key outside an Azure Cloud Shell session.")
Expand Down Expand Up @@ -190,7 +187,7 @@ def _do_ssh_op(cmd, op_info, op_call):

try:
if op_info.is_arc():
op_info.proxy_path = connectivity_utils.get_client_side_proxy(op_info.ssh_proxy_folder)
op_info.proxy_path = connectivity_utils.install_client_side_proxy(op_info.ssh_proxy_folder)
(op_info.relay_info, op_info.new_service_config) = connectivity_utils.get_relay_information(
cmd, op_info.resource_group_name, op_info.vm_name, op_info.resource_type,
cert_lifetime, op_info.port, op_info.yes_without_prompt)
Expand Down
1 change: 0 additions & 1 deletion src/ssh/azext_ssh/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,3 @@ def remove_invalid_characters_foldername(folder_name):
if c not in const.WINDOWS_INVALID_FOLDERNAME_CHARS:
new_foldername += c
return new_foldername

101 changes: 94 additions & 7 deletions src/ssh/azext_ssh/tests/latest/test_connectivity_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import io
import unittest
from unittest import mock
from azext_ssh import custom
from azext_ssh import rdp_utils
from azext_ssh import ssh_utils
from azext_ssh import connectivity_utils

from azure.cli.core import azclierror
Expand Down Expand Up @@ -57,10 +53,101 @@ def test_get_client_os_unsupported(self, mock_plat):
with self.assertRaises(azclierror.BadRequestError):
arch = connectivity_utils._get_client_operating_system()

def test_get_proxy_filename_windows(self):
def test_get_proxy_filename_amd_windows(self):
name = connectivity_utils._get_proxy_filename('Windows', 'amd64')
self.assertEqual(name, 'sshProxy_windows_amd64_1_3_026973.exe')

def test_get_proxy_filename_linux(self):
def test_get_proxy_filename_arm_linux(self):
name = connectivity_utils._get_proxy_filename('Linux', 'arm64')
self.assertEqual(name, 'sshProxy_linux_arm64_1_3_026973')
self.assertEqual(name, 'sshProxy_linux_arm64_1_3_026973')

def test_get_proxy_filename_arm_Darwin(self):
name = connectivity_utils._get_proxy_filename('Darwin', 'arm64')
self.assertEqual(name, 'sshProxy_darwin_arm64_1_3_026973')

def test_get_proxy_filename_386_linuux(self):
name = connectivity_utils._get_proxy_filename('Linux', '386')
self.assertEqual(name, 'sshProxy_linux_386_1_3_026973')

def test_get_proxy_filename_386_darwin(self):
with self.assertRaises(azclierror.BadRequestError):
name = connectivity_utils._get_proxy_filename('Darwin', '386')

@mock.patch('os.path.isfile')
def test_check_proxy_is_installed_fail(self, mock_isfile):
mock_isfile.side_effect = [False, True, True]
with self.assertRaises(azclierror.CLIInternalError):
connectivity_utils._check_proxy_installation("/dir/", "proxy")

@mock.patch('os.path.isfile')
def test_check_proxy_is_installed_sucess(self, mock_isfile):
mock_isfile.side_effect = [True, True, True]
connectivity_utils._check_proxy_installation("/dir/", "proxy")

@mock.patch('os.path.isfile')
def test_check_proxy_is_installed_fail_licenses(self, mock_isfile):
mock_isfile.side_effect = [True, False, False]
connectivity_utils._check_proxy_installation("/dir/", "proxy")

@mock.patch('tarfile.open')
def test_extract_proxy_from_tar(self, mock_open):
mock_tar = mock_open.return_value.__enter__.return_value

mock_file1 = mock.Mock()
mock_file1.name = "dir/sshproxy"
mock_file1.isfile = mock.Mock(return_value=True)

mock_file2 = mock.Mock()
mock_file2.name = "dir/license.txt"
mock_file2.isfile = mock.Mock(return_value=True)

mock_file3 = mock.Mock()
mock_file3.name = "dir/thirdpartynotice.txt"
mock_file3.isfile = mock.Mock(return_value=True)

mock_file4 = mock.Mock()
mock_file4.name = "dir"
mock_file4.isfile = mock.Mock(return_value=False)

mock_tar.getmembers.return_value = [mock_file1, mock_file2, mock_file3, mock_file4]

connectivity_utils._extract_proxy_tar_files("proxy_package.tar.gz", "/tmp/install", "my_proxy")

mock_tar.extractall.assert_called_once_with(members=[mock_file1, mock_file2, mock_file3], path="/tmp/install")

self.assertEquals(mock_file1.name, "my_proxy")
self.assertEquals(mock_file2.name, "license.txt")
self.assertEquals(mock_file3.name, "thirdpartynotice.txt")

@mock.patch('os.path.isfile')
@mock.patch('platform.machine')
@mock.patch('platform.system')
@mock.patch('azext_ssh.connectivity_utils._get_proxy_install_dir')
@mock.patch('os.path.join')
@mock.patch('azext_ssh.file_utils.create_directory')
@mock.patch('azext_ssh.connectivity_utils._download_proxy_from_MCR')
@mock.patch('azext_ssh.connectivity_utils._check_proxy_installation')
def test_install_proxy_create_dir(self, mock_check, mock_download, mock_dir, mock_join, mock_get_proxy_dir, mock_sys, mock_machine, mock_isfile):
mock_machine.return_value = 'aarch64'
mock_sys.return_value = 'linux'
mock_get_proxy_dir.return_value = "/dir/proxy"
mock_isfile.return_value = False

connectivity_utils.install_client_side_proxy(None)

mock_dir.assert_called_once_with("/dir/proxy", "Failed to create client proxy directory \'/dir/proxy\'.")
mock_download.assert_called_once_with("/dir/proxy", "sshProxy_linux_arm64_1_3_026973", "linux", "arm64")
mock_check.assert_called_once_with("/dir/proxy", "sshProxy_linux_arm64_1_3_026973")













7 changes: 3 additions & 4 deletions src/ssh/azext_ssh/tests/latest/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from azure.cli.core import azclierror

'''

class SshCustomCommandTest(unittest.TestCase):

@mock.patch('azext_ssh.custom._do_ssh_op')
Expand Down Expand Up @@ -414,7 +414,7 @@ def test_do_ssh_op_no_public_ip(self, mock_ip, mock_check_files):
mock_ip.assert_called_once_with(cmd, "rg", "vm", False)
mock_op.assert_not_called()

@mock.patch('azext_ssh.connectivity_utils.get_client_side_proxy')
@mock.patch('azext_ssh.connectivity_utils.install_client_side_proxy')
@mock.patch('azext_ssh.connectivity_utils.get_relay_information')
@mock.patch('azext_ssh.ssh_utils.start_ssh_connection')
@mock.patch('azext_ssh.custom._check_or_create_public_private_files')
Expand All @@ -438,7 +438,7 @@ def test_do_ssh_op_arc_local_user(self, mock_get_cert, mock_check_keys, mock_sta
mock_get_cert.assert_not_called()
mock_check_keys.assert_not_called()

@mock.patch('azext_ssh.connectivity_utils.get_client_side_proxy')
@mock.patch('azext_ssh.connectivity_utils.install_client_side_proxy')
@mock.patch('azext_ssh.custom.connectivity_utils.get_relay_information')
@mock.patch('azext_ssh.ssh_utils.get_ssh_cert_principals')
@mock.patch('os.path.join')
Expand Down Expand Up @@ -489,4 +489,3 @@ def test_do_ssh_arc_op_aad_user(self, mock_cert_exp, mock_start_ssh, mock_write_

if __name__ == '__main__':
unittest.main()
'''
3 changes: 1 addition & 2 deletions src/ssh/azext_ssh/tests/latest/test_rdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from azext_ssh import ssh_info
from azext_ssh import ssh_utils

'''

class RDPUtilsTest(unittest.TestCase):
@mock.patch('os.environ.copy')
@mock.patch.object(ssh_utils, 'get_ssh_client_path')
Expand Down Expand Up @@ -96,4 +96,3 @@ def test_start_rdp_connection(self, mock_terminate, mock_rdp, mock_wait, mock_tu

if __name__ == '__main__':
unittest.main()
'''
3 changes: 1 addition & 2 deletions src/ssh/azext_ssh/tests/latest/test_resource_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from azure.cli.core import azclierror

'''

class SshResourceTypeUtilsCommandTest(unittest.TestCase):

@mock.patch('azext_ssh.resource_type_utils._list_types_of_resources_with_provided_name')
Expand Down Expand Up @@ -106,4 +106,3 @@ def test_decide_resource_type_rg_vmware(self, mock_list_types):

if __name__ == '__main__':
unittest.main()
'''
3 changes: 1 addition & 2 deletions src/ssh/azext_ssh/tests/latest/test_rsa_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from azext_ssh import rsa_parser

'''

class RSAParserTest(unittest.TestCase):
def test_rsa_parser_success(self):
public_key_text = 'ssh-rsa ' + self._get_good_key()
Expand Down Expand Up @@ -84,4 +84,3 @@ def _get_good_exponent(self):

if __name__ == '__main__':
unittest.main()
'''
Loading

0 comments on commit 89e0b6c

Please sign in to comment.