Skip to content

Commit

Permalink
Use the same SSH connection for main thread and completer thread
Browse files Browse the repository at this point in the history
  • Loading branch information
gfrlv committed May 25, 2020
1 parent 7a059a6 commit 456f86b
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 100 deletions.
8 changes: 5 additions & 3 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ TBD

Features:
---------
* Add an option `--ssh-config-host` to read ssh configuration from OpenSSH configuration file.
* Add an option `--list-ssh-config` to list ssh configurations.
* Add an option `--ssh-config-path` to choose ssh configuration path.
* Add an option `--ssh-config-host` to read ssh configuration from OpenSSH configuration file (Thanks: [Nathan Huang]).
* Add an option `--list-ssh-config` to list ssh configurations (Thanks: [Nathan Huang]).
* Add an option `--ssh-config-path` to choose ssh configuration path (Thanks: [Nathan Huang]).
* Reuse the same SSH connection in both main thread and completion thread (Thanks: [Georgy Frolov]).


1.21.1
Expand Down Expand Up @@ -757,3 +758,4 @@ Bug Fixes:
[François Pietka]: https://github.com/fpietka
[Frederic Aoustin]: https://github.com/fraoustin
[Georgy Frolov]: https://github.com/pasenor
[Nathan Huang]: https://github.com/hxueh
3 changes: 1 addition & 2 deletions mycli/completion_refresher.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ def _bg_refresh(self, sqlexecute, callbacks, completer_options):
e = sqlexecute
executor = SQLExecute(e.dbname, e.user, e.password, e.host, e.port,
e.socket, e.charset, e.local_infile, e.ssl,
e.ssh_user, e.ssh_host, e.ssh_port,
e.ssh_password, e.ssh_key_filename)
e.ssh_client)

# If callbacks is a single function then push it into a list.
if callable(callbacks):
Expand Down
64 changes: 18 additions & 46 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@
from prompt_toolkit.history import FileHistory
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory

from mycli.packages.ssh_client import create_ssh_client
from .packages.special.main import NO_QUERY
from .packages.prompt_utils import confirm, confirm_destructive_query
from .packages.tabular_output import sql_format
from .packages import special
from .packages import special, ssh_client
from .packages.special.favoritequeries import FavoriteQueries
from .sqlcompleter import SQLCompleter
from .clitoolbar import create_toolbar_tokens_func
Expand Down Expand Up @@ -63,11 +64,6 @@
from urllib.parse import unquote


try:
import paramiko
except ImportError:
from mycli.packages.paramiko_stub import paramiko

# Query tuples are used for maintaining history
Query = namedtuple('Query', ['query', 'successful', 'mutating'])

Expand Down Expand Up @@ -198,6 +194,8 @@ def __init__(self, sqlexecute=None, prompt=None,

self.prompt_app = None

self.ssh_client = None

def register_special_commands(self):
special.register_special_command(self.change_db, 'use',
'\\u', 'Change to a new database.', aliases=('\\u',))
Expand Down Expand Up @@ -358,9 +356,7 @@ def merge_ssl_with_cnf(self, ssl, cnf):
return merged

def connect(self, database='', user='', passwd='', host='', port='',
socket='', charset='', local_infile='', ssl='',
ssh_user='', ssh_host='', ssh_port='',
ssh_password='', ssh_key_filename=''):
socket='', charset='', local_infile='', ssl=None):

cnf = {'database': None,
'user': None,
Expand All @@ -384,7 +380,7 @@ def connect(self, database='', user='', passwd='', host='', port='',

database = database or cnf['database']
# Socket interface not supported for SSH connections
if port or host or ssh_host or ssh_port:
if port or host or self.ssh_client:
socket = ''
else:
socket = socket or cnf['socket'] or guess_socket_location()
Expand Down Expand Up @@ -416,17 +412,15 @@ def _connect():
try:
self.sqlexecute = SQLExecute(
database, user, passwd, host, port, socket, charset,
local_infile, ssl, ssh_user, ssh_host, ssh_port,
ssh_password, ssh_key_filename
local_infile, ssl, ssh_client=self.ssh_client
)
except OperationalError as e:
if ('Access denied for user' in e.args[1]):
new_passwd = click.prompt('Password', hide_input=True,
show_default=False, type=str, err=True)
self.sqlexecute = SQLExecute(
database, user, new_passwd, host, port, socket,
charset, local_infile, ssl, ssh_user, ssh_host,
ssh_port, ssh_password, ssh_key_filename
charset, local_infile, ssl, ssh_client=self.ssh_client
)
else:
raise e
Expand Down Expand Up @@ -1092,16 +1086,17 @@ def cli(database, user, host, port, socket, password, dbname,
else:
click.secho(alias)
sys.exit(0)

if list_ssh_config:
ssh_config = read_ssh_config(ssh_config_path)
for host in ssh_config.get_hostnames():
hosts = ssh_client.get_config_hosts(ssh_config_path)
for host, hostname in hosts.items():
if verbose:
host_config = ssh_config.lookup(host)
click.secho("{} : {}".format(
host, host_config.get('hostname')))
host, hostname))
else:
click.secho(host)
sys.exit(0)

# Choose which ever one has a valid value.
database = dbname or database

Expand Down Expand Up @@ -1153,7 +1148,7 @@ def cli(database, user, host, port, socket, password, dbname,
port = uri.port

if ssh_config_host:
ssh_config = read_ssh_config(
ssh_config = ssh_client.read_config_file(
ssh_config_path
).lookup(ssh_config_host)
ssh_host = ssh_host if ssh_host else ssh_config.get('hostname')
Expand All @@ -1164,7 +1159,10 @@ def cli(database, user, host, port, socket, password, dbname,
ssh_key_filename = ssh_key_filename if ssh_key_filename else ssh_config.get(
'identityfile', [None])[0]

ssh_key_filename = ssh_key_filename and os.path.expanduser(ssh_key_filename)
if ssh_host:
mycli.ssh_client = create_ssh_client(
ssh_host, ssh_port, ssh_user, ssh_password, ssh_key_filename
)

mycli.connect(
database=database,
Expand All @@ -1175,11 +1173,6 @@ def cli(database, user, host, port, socket, password, dbname,
socket=socket,
local_infile=local_infile,
ssl=ssl,
ssh_user=ssh_user,
ssh_host=ssh_host,
ssh_port=ssh_port,
ssh_password=ssh_password,
ssh_key_filename=ssh_key_filename
)

mycli.logger.debug('Launch Params: \n'
Expand Down Expand Up @@ -1298,26 +1291,5 @@ def edit_and_execute(event):
buff.open_in_editor(validate_and_handle=False)


def read_ssh_config(ssh_config_path):
ssh_config = paramiko.config.SSHConfig()
try:
with open(ssh_config_path) as f:
ssh_config.parse(f)
# Paramiko prior to version 2.7 raises Exception on parse errors.
# In 2.7 it has become paramiko.ssh_exception.SSHException,
# but let's catch everything for compatibility
except Exception as err:
click.secho(
f'Could not parse SSH configuration file {ssh_config_path}:\n{err} ',
err=True, fg='red'
)
sys.exit(1)
except FileNotFoundError as e:
click.secho(str(e), err=True, fg='red')
sys.exit(1)
else:
return ssh_config


if __name__ == "__main__":
cli()
1 change: 1 addition & 0 deletions mycli/packages/ssh_client/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .client import get_config_hosts, create_ssh_client, SSHException, read_config_file
46 changes: 46 additions & 0 deletions mycli/packages/ssh_client/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""A very thin wrapper around paramiko, mostly to keep all SSH-related
functionality in one place."""
from io import open

try:
import paramiko
except ImportError:
from mycli.packages.paramiko_stub import paramiko


class SSHException(Exception):
pass


def get_config_hosts(config_path):
config = read_config_file(config_path)
return {
host: config.lookup(host).get("hostname") for host in config.get_hostnames()
}


def create_ssh_client(ssh_host, ssh_port, ssh_user, ssh_password=None, ssh_key_filename=None) -> paramiko.SSHClient:
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy())
client.connect(
ssh_host, ssh_port, ssh_user, password=ssh_password, key_filename=ssh_key_filename
)
return client


def read_config_file(config_path) -> paramiko.SSHConfig:
ssh_config = paramiko.config.SSHConfig()
try:
with open(config_path) as f:
ssh_config.parse(f)
# Paramiko prior to version 2.7 raises Exception on parse errors.
# In 2.7 it has become paramiko.ssh_exception.SSHException,
# but let's catch everything for compatibility
except Exception as err:
raise SSHException(
f"Could not parse SSH configuration file {config_path}:\n{err} ",
)
except FileNotFoundError as e:
raise SSHException(str(e))
return ssh_config
50 changes: 11 additions & 39 deletions mycli/sqlexecute.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
from pymysql.converters import (convert_mysql_timestamp, convert_datetime,
convert_timedelta, convert_date, conversions,
decoders)
try:
import paramiko
except ImportError:
from mycli.packages.paramiko_stub import paramiko


_logger = logging.getLogger(__name__)

Expand All @@ -18,6 +15,7 @@
FIELD_TYPE.NULL: type(None)
})


class SQLExecute(object):

databases_query = '''SHOW DATABASES'''
Expand All @@ -41,8 +39,8 @@ class SQLExecute(object):
order by table_name,ordinal_position'''

def __init__(self, database, user, password, host, port, socket, charset,
local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password,
ssh_key_filename):
local_infile, ssl,
ssh_client=None):
self.dbname = database
self.user = user
self.password = password
Expand All @@ -54,17 +52,12 @@ def __init__(self, database, user, password, host, port, socket, charset,
self.ssl = ssl
self._server_type = None
self.connection_id = None
self.ssh_user = ssh_user
self.ssh_host = ssh_host
self.ssh_port = ssh_port
self.ssh_password = ssh_password
self.ssh_key_filename = ssh_key_filename
self.ssh_client = ssh_client

self.connect()

def connect(self, database=None, user=None, password=None, host=None,
port=None, socket=None, charset=None, local_infile=None,
ssl=None, ssh_host=None, ssh_port=None, ssh_user=None,
ssh_password=None, ssh_key_filename=None):
port=None, socket=None, charset=None, local_infile=None, ssl=None):
db = (database or self.dbname)
user = (user or self.user)
password = (password or self.password)
Expand All @@ -74,11 +67,6 @@ def connect(self, database=None, user=None, password=None, host=None,
charset = (charset or self.charset)
local_infile = (local_infile or self.local_infile)
ssl = (ssl or self.ssl)
ssh_user = (ssh_user or self.ssh_user)
ssh_host = (ssh_host or self.ssh_host)
ssh_port = (ssh_port or self.ssh_port)
ssh_password = (ssh_password or self.ssh_password)
ssh_key_filename = (ssh_key_filename or self.ssh_key_filename)
_logger.debug(
'Connection DB Params: \n'
'\tdatabase: %r'
Expand All @@ -88,14 +76,8 @@ def connect(self, database=None, user=None, password=None, host=None,
'\tsocket: %r'
'\tcharset: %r'
'\tlocal_infile: %r'
'\tssl: %r'
'\tssh_user: %r'
'\tssh_host: %r'
'\tssh_port: %r'
'\tssh_password: %r'
'\tssh_key_filename: %r',
'\tssl: %r',
db, user, host, port, socket, charset, local_infile, ssl,
ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename
)
conv = conversions.copy()
conv.update({
Expand All @@ -107,26 +89,16 @@ def connect(self, database=None, user=None, password=None, host=None,

defer_connect = False

if ssh_host:
defer_connect = True

conn = pymysql.connect(
database=db, user=user, password=password, host=host, port=port,
unix_socket=socket, use_unicode=True, charset=charset,
autocommit=True, client_flag=pymysql.constants.CLIENT.INTERACTIVE,
local_infile=local_infile, conv=conv, ssl=ssl, program_name="mycli",
defer_connect=defer_connect
defer_connect=self.ssh_client is not None
)

if ssh_host:
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy())
client.connect(
ssh_host, ssh_port, ssh_user, ssh_password,
key_filename=ssh_key_filename
)
chan = client.get_transport().open_channel(
if self.ssh_client:
chan = self.ssh_client.get_transport().open_channel(
'direct-tcpip',
(host, port),
('0.0.0.0', 0),
Expand Down
10 changes: 8 additions & 2 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import pytest

from mycli.packages.ssh_client import create_ssh_client
from .utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db,
db_connection, SSH_USER, SSH_HOST, SSH_PORT)
import mycli.sqlexecute
Expand All @@ -21,9 +23,13 @@ def cursor(connection):

@pytest.fixture
def executor(connection):
if SSH_HOST:
ssh_client = create_ssh_client(SSH_HOST, SSH_PORT, SSH_USER)
else:
ssh_client = None

return mycli.sqlexecute.SQLExecute(
database='_test_db', user=USER,
host=HOST, password=PASSWORD, port=PORT, socket=None, charset=CHARSET,
local_infile=False, ssl=None, ssh_user=SSH_USER, ssh_host=SSH_HOST,
ssh_port=SSH_PORT, ssh_password=None, ssh_key_filename=None
local_infile=False, ssl=None, ssh_client=ssh_client
)
Loading

0 comments on commit 456f86b

Please sign in to comment.