diff --git a/openwisp_controller/connection/connectors/ssh.py b/openwisp_controller/connection/connectors/ssh.py index e4304a717..5fbfb3f1a 100644 --- a/openwisp_controller/connection/connectors/ssh.py +++ b/openwisp_controller/connection/connectors/ssh.py @@ -1,13 +1,11 @@ import logging import socket import sys -from io import BytesIO import paramiko from django.utils.functional import cached_property from jsonschema import validate from jsonschema.exceptions import ValidationError as SchemaError -from scp import SCPClient if sys.version_info.major > 2: # pragma: nocover from io import StringIO diff --git a/openwisp_controller/connection/tests/test_admin.py b/openwisp_controller/connection/tests/test_admin.py index 70e2bce0c..3133a5c01 100644 --- a/openwisp_controller/connection/tests/test_admin.py +++ b/openwisp_controller/connection/tests/test_admin.py @@ -1,12 +1,9 @@ -import json - from django.test import TestCase from django.urls import reverse from ...config.models import Template from ...config.tests.test_admin import TestAdmin as TestConfigAdmin from ...tests.utils import TestAdminMixin -from .. import settings as app_settings from ..models import Credentials, DeviceConnection, DeviceIp from .base import CreateConnectionsMixin, SshServerMixin @@ -81,7 +78,7 @@ def test_connection_credentials_fk_queryset(self): data = self._create_multitenancy_test_env() self._test_multitenant_admin( url=reverse('admin:config_device_add'), - visible=[data['cred1'].name + ' (SSH)'], - hidden=[data['cred2'].name + ' (SSH)', data['cred3_inactive']], + visible=[str(data['cred1'].name) + str(" (SSH)")], + hidden=[str(data['cred2'].name) + str(" (SSH)"), data['cred3_inactive']], select_widget=True ) diff --git a/openwisp_controller/connection/tests/test_models.py b/openwisp_controller/connection/tests/test_models.py index e1e950455..48eabd5dd 100644 --- a/openwisp_controller/connection/tests/test_models.py +++ b/openwisp_controller/connection/tests/test_models.py @@ -1,6 +1,7 @@ import paramiko from django.core.exceptions import ValidationError from django.test import TestCase +from mock import Mock, patch from openwisp_users.models import Organization @@ -274,14 +275,55 @@ def test_device_config_update(self): } ] } - # here you can start capturing standard output - # see how it's done here: https://github.com/ncouture/MockSSH/blob/master/tests/test_mock_cisco.py#L14-L20 c.full_clean() c.save() - # at this point the update_config() method will be triggered - # you need to create a command in the mocked SSH server that - # just prints something like: mock-ssh-server: openwisp-config restarted - # so you can then do: - # self.assertIn('mock-ssh-server: openwisp-config restarted', sdout) - # if update_config() finishes successfully then status will be set to "applied" - self.assertEqual(device.status, 'applied') + device.refresh_from_db() + self.assertEqual(device.status, 'modified') + c.config = { + 'interfaces': [ + { + 'name': 'eth10', + 'type': 'ethernet', + 'addresses': [ + { + 'family': 'ipv4', + 'proto': 'dhcp' + } + ] + } + ] + } + stdin = Mock() + stdout = Mock() + stderr = Mock() + stdout.read().decode('utf8').strip.return_value = \ + 'Mock Object: executed command successfully' + stdout.channel.recv_exit_status.return_value = 0 + stderr.read().decode('utf8').strip.return_value = '' + with patch('paramiko.SSHClient.exec_command') as mock_object: + mock_object.return_value = [stdin, stdout, stderr] + c.full_clean() + c.save() + device.refresh_from_db() + self.assertEqual(paramiko.SSHClient.exec_command.call_count, 1) + actual = str(paramiko.SSHClient.exec_command.call_args) + expected = str("call('/etc/init.d/openwisp_config restart')") + self.assertEqual(expected, actual) + self.assertEqual(device.status, 'applied') + + def test_ssh_exec_exist_status(self): + ckey = self._create_credentials_with_key(port=self.ssh_server.port) + dc = self._create_device_connection(credentials=ckey) + self._create_device_ip(address=self.ssh_server.host, + device=dc.device) + dc.connect() + _, exit_status = dc.connector_instance.exec_command('ls') + self.assertEqual(exit_status, 0) + + def test_ssh_exec_exception(self): + ckey = self._create_credentials_with_key(port=self.ssh_server.port) + dc = self._create_device_connection(credentials=ckey) + self._create_device_ip(address=self.ssh_server.host, + device=dc.device) + with self.assertRaises(Exception): + dc.connector_instance.exec_command('ls')