Skip to content

Commit

Permalink
Add the SshFabricTransport plugin
Browse files Browse the repository at this point in the history
This transport plugin subclasses the `SshTransport` plugin but replaces
the managing of the connection with the `fabric` library. Under the hood
this library still uses `paramiko`, however, it is a lot clevered in
automatically determining the correct connection configuration. For
example, it will automatically search for SSH configurations on the
machine. This makes it easier for the user to setup the transport to
connect to a particular host if they can connect to it through a direct
`ssh` invocation`. Since `fabric` still uses `paramiko` under the hood,
just like the `SshTransport`, most of the implementation can be reused.
  • Loading branch information
sphuber committed Mar 25, 2024
1 parent 65786a6 commit 55da99b
Show file tree
Hide file tree
Showing 11 changed files with 100 additions and 0 deletions.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies:
- click~=8.1
- disk-objectstore~=1.0
- docstring_parser
- fabric~=3.0
- get-annotations~=0.1
- python-graphviz~=0.19
- ipython>=7
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies = [
'click~=8.1',
'disk-objectstore~=1.0',
'docstring-parser',
'fabric~=3.0',
'get-annotations~=0.1;python_version<"3.10"',
'graphviz~=0.19',
'ipython>=7',
Expand Down Expand Up @@ -175,6 +176,7 @@ requires-python = '>=3.9'
[project.entry-points.'aiida.transports']
'core.local' = 'aiida.transports.plugins.local:LocalTransport'
'core.ssh' = 'aiida.transports.plugins.ssh:SshTransport'
'core.ssh_fabric' = 'aiida.transports.plugins.ssh_fabric:SshFabricTransport'

[project.entry-points.'aiida.workflows']
'core.arithmetic.add_multiply' = 'aiida.workflows.arithmetic.add_multiply:add_multiply'
Expand Down Expand Up @@ -310,6 +312,7 @@ module = [
'circus.*',
'click_spinner.*',
'docutils.*',
'fabric.*',
'flask_cors.*',
'flask_restful.*',
'get_annotations.*',
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements-py-3.10.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ docutils==0.20.1
emmet-core==0.57.1
exceptiongroup==1.1.1
executing==1.2.0
fabric==3.2.2
fastjsonschema==2.17.1
flask==2.3.2
flask-cors==3.0.10
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements-py-3.11.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ docstring-parser==0.15
docutils==0.20.1
emmet-core==0.57.1
executing==1.2.0
fabric==3.2.2
fastjsonschema==2.17.1
flask==2.3.2
flask-cors==3.0.10
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements-py-3.12.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ disk-objectstore==1.0.0
docstring-parser==0.15
docutils==0.20.1
executing==2.0.0
fabric==3.2.2
fastjsonschema==2.18.1
flask==2.3.3
flask-cors==3.0.10
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements-py-3.9.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ docutils==0.20.1
emmet-core==0.57.1
exceptiongroup==1.1.1
executing==1.2.0
fabric==3.2.2
fastjsonschema==2.17.1
flask==2.3.2
flask-cors==3.0.10
Expand Down
1 change: 1 addition & 0 deletions src/aiida/transports/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .transport import *

__all__ = (
'SshFabricTransport',
'SshTransport',
'Transport',
'convert_to_bool',
Expand Down
2 changes: 2 additions & 0 deletions src/aiida/transports/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# fmt: off

from .ssh import *
from .ssh_fabric import *

__all__ = (
'SshFabricTransport',
'SshTransport',
'convert_to_bool',
'parse_sshconfig',
Expand Down
73 changes: 73 additions & 0 deletions src/aiida/transports/plugins/ssh_fabric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Transport plugin for SSH connection using ``fabric``.
This subclasses the ``SshTransport`` plugin to replace the connection configuration using the ``fabric`` package. This
makes the configuration significantly easier for the user since ``fabric`` automatically guesses connection parameters
by parsing available SSH configuration files on the system. Since ``fabric`` uses ``paramiko`` under the hood, which is
the library used by the ``SshTransport``, most of the implementation can be reused. The ``paramiko.SshClient`` is
exposed on the ``fabric.Connection`` instance and is proxied to the ``_client`` attribute of the class where the
``SshTransport`` base class expects to find it.
"""
import fabric

from aiida.common.exceptions import InvalidOperation

from ..transport import Transport
from .ssh import SshTransport

__all__ = ('SshFabricTransport',)


class SshFabricTransport(SshTransport):
"""Transport plugin for SSH connections with highly-automated connection configuration.
The connection is made using the ``fabric`` package which will try to use available SSH configuration files on the
system to guess the correct configuration for the connection to the relevant hostname.
"""

_valid_auth_options = []

def __init__(self, *args, **kwargs):
"""Construct a new instance."""
Transport.__init__(self, *args, **kwargs) # Skip the ``SshTransport`` constructor
self._connection = fabric.Connection(self.hostname)
self._client = self._connection.client
self._sftp = None

def open(self):
"""Open the connection."""
if not self._connection.is_connected:
self._sftp = self._connection.sftp() # This opens the fabric connection ``self._connection`` as well
self._sftp.chdir('.') # Needed to make sure sftp.getcwd() returns a valid path
self._is_open = True

def close(self):
"""Close the connection.
This will close the SFTP channel and the ``fabric`` connection.
:raise aiida.common.InvalidOperation: If the channel is already open.
"""
if not self._is_open:
raise InvalidOperation('Cannot close the transport: it is already closed')

if self._sftp:
self._sftp.close()
self._connection.close()
self._is_open = False

def __str__(self):
"""Return a string representation of the transport instance."""
status = 'OPEN' if self._is_open else 'CLOSED'
return f'{self._connection.user}@{self.hostname}:{self._connection.port} [{status}]'

def gotocomputer_command(self, remotedir):
"""Return the command string to connect to the remote and change directory to ``remotedir``."""
return f'ssh -t {self.hostname} {self._gotocomputer_string(remotedir)}'
14 changes: 14 additions & 0 deletions tests/cmdline/commands/test_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,3 +758,17 @@ def time_use_login_shell(authinfo, auth_params, use_login_shell, iterations) ->
result = run_cli_command(computer_test, [aiida_localhost.label], use_subprocess=False)
assert 'Success: all 6 tests succeeded' in result.output
assert 'computer is configured to use a login shell, which is slower compared to a normal shell' in result.output


def test_computer_ssh_fabric(run_cli_command, aiida_computer):
"""Test setup of computer with ``core.ssh_fabric`` entry point.
The configure step should only require the common shared options ``safe_interval`` and ``use_login_shell``.
"""
computer = aiida_computer(transport_type='core.ssh_fabric').store()
assert not computer.is_configured

# It is important that no other options (except for `--safe-interval`) have to be specified for this transport type.
options = ['core.ssh_fabric', computer.uuid, '--non-interactive', '--safe-interval', '0']
run_cli_command(computer_configure, options, use_subprocess=False)
assert computer.is_configured
2 changes: 2 additions & 0 deletions tests/transports/test_all_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def custom_transport(request) -> Transport:
"""Fixture that parametrizes over all the registered implementations of the ``CommonRelaxWorkChain``."""
if request.param == 'core.ssh':
kwargs = {'machine': 'localhost', 'timeout': 30, 'load_system_host_keys': True, 'key_policy': 'AutoAddPolicy'}
elif request.param == 'core.ssh_fabric':
kwargs = {'machine': 'localhost'}
else:
kwargs = {}

Expand Down

0 comments on commit 55da99b

Please sign in to comment.