-
Notifications
You must be signed in to change notification settings - Fork 665
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reuse ssh connection #869
base: main
Are you sure you want to change the base?
Reuse ssh connection #869
Changes from 3 commits
456f86b
3db1e2d
57b7520
664e62c
deff68f
d9c604d
6392029
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,10 +34,11 @@ | |
from prompt_toolkit.history import FileHistory | ||
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory | ||
|
||
from mycli.packages.ssh_client import create_ssh_client | ||
from .packages.special.main import NO_QUERY | ||
from .packages.prompt_utils import confirm, confirm_destructive_query | ||
from .packages.tabular_output import sql_format | ||
from .packages import special | ||
from .packages import special, ssh_client | ||
from .packages.special.favoritequeries import FavoriteQueries | ||
from .sqlcompleter import SQLCompleter | ||
from .clitoolbar import create_toolbar_tokens_func | ||
|
@@ -66,11 +67,6 @@ | |
from urllib.parse import unquote | ||
|
||
|
||
try: | ||
import paramiko | ||
except ImportError: | ||
from mycli.packages.paramiko_stub import paramiko | ||
|
||
# Query tuples are used for maintaining history | ||
Query = namedtuple('Query', ['query', 'successful', 'mutating']) | ||
|
||
|
@@ -201,6 +197,8 @@ def __init__(self, sqlexecute=None, prompt=None, | |
|
||
self.prompt_app = None | ||
|
||
self.ssh_client = None | ||
|
||
def register_special_commands(self): | ||
special.register_special_command(self.change_db, 'use', | ||
'\\u', 'Change to a new database.', aliases=('\\u',)) | ||
|
@@ -361,9 +359,7 @@ def merge_ssl_with_cnf(self, ssl, cnf): | |
return merged | ||
|
||
def connect(self, database='', user='', passwd='', host='', port='', | ||
socket='', charset='', local_infile='', ssl='', | ||
ssh_user='', ssh_host='', ssh_port='', | ||
ssh_password='', ssh_key_filename='', init_command=''): | ||
socket='', charset='', local_infile='', ssl=None, init_command=''): | ||
|
||
cnf = {'database': None, | ||
'user': None, | ||
|
@@ -387,7 +383,7 @@ def connect(self, database='', user='', passwd='', host='', port='', | |
|
||
database = database or cnf['database'] | ||
# Socket interface not supported for SSH connections | ||
if port or host or ssh_host or ssh_port: | ||
if port or host or self.ssh_client: | ||
socket = '' | ||
else: | ||
socket = socket or cnf['socket'] or guess_socket_location() | ||
|
@@ -419,17 +415,16 @@ def _connect(): | |
try: | ||
self.sqlexecute = SQLExecute( | ||
database, user, passwd, host, port, socket, charset, | ||
local_infile, ssl, ssh_user, ssh_host, ssh_port, | ||
ssh_password, ssh_key_filename, init_command | ||
local_infile, ssl, init_command, ssh_client=self.ssh_client | ||
) | ||
except OperationalError as e: | ||
if ('Access denied for user' in e.args[1]): | ||
new_passwd = click.prompt('Password', hide_input=True, | ||
show_default=False, type=str, err=True) | ||
self.sqlexecute = SQLExecute( | ||
database, user, new_passwd, host, port, socket, | ||
charset, local_infile, ssl, ssh_user, ssh_host, | ||
ssh_port, ssh_password, ssh_key_filename, init_command | ||
charset, local_infile, ssl, init_command, | ||
ssh_client=self.ssh_client | ||
) | ||
else: | ||
raise e | ||
|
@@ -1098,16 +1093,22 @@ def cli(database, user, host, port, socket, password, dbname, | |
else: | ||
click.secho(alias) | ||
sys.exit(0) | ||
|
||
if list_ssh_config: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Side note: the list_ssh_config feature is unreliable if the SSH config file uses certain features such as Match. Listing the hosts is not really in the scope of a tool such as mycli, and we ought to consider removing it. |
||
ssh_config = read_ssh_config(ssh_config_path) | ||
for host in ssh_config.get_hostnames(): | ||
try: | ||
hosts = ssh_client.get_config_hosts(ssh_config_path) | ||
except ssh_client.SSHException as e: | ||
click.secho(str(e), err=True, fg='red') | ||
sys.exit(1) | ||
|
||
for host, hostname in hosts.items(): | ||
if verbose: | ||
host_config = ssh_config.lookup(host) | ||
click.secho("{} : {}".format( | ||
host, host_config.get('hostname'))) | ||
host, hostname)) | ||
else: | ||
click.secho(host) | ||
sys.exit(0) | ||
|
||
# Choose which ever one has a valid value. | ||
database = dbname or database | ||
|
||
|
@@ -1159,9 +1160,14 @@ def cli(database, user, host, port, socket, password, dbname, | |
port = uri.port | ||
|
||
if ssh_config_host: | ||
ssh_config = read_ssh_config( | ||
ssh_config_path | ||
).lookup(ssh_config_host) | ||
try: | ||
ssh_config = ssh_client.read_config_file( | ||
ssh_config_path | ||
).lookup(ssh_config_host) | ||
except ssh_client.SSHException as e: | ||
click.secho(str(e), err=True, fg='red') | ||
sys.exit(1) | ||
|
||
ssh_host = ssh_host if ssh_host else ssh_config.get('hostname') | ||
ssh_user = ssh_user if ssh_user else ssh_config.get('user') | ||
if ssh_config.get('port') and ssh_port == 22: | ||
|
@@ -1170,7 +1176,10 @@ def cli(database, user, host, port, socket, password, dbname, | |
ssh_key_filename = ssh_key_filename if ssh_key_filename else ssh_config.get( | ||
'identityfile', [None])[0] | ||
|
||
ssh_key_filename = ssh_key_filename and os.path.expanduser(ssh_key_filename) | ||
if ssh_host: | ||
mycli.ssh_client = create_ssh_client( | ||
ssh_host, ssh_port, ssh_user, ssh_password, ssh_key_filename | ||
) | ||
|
||
mycli.connect( | ||
database=database, | ||
|
@@ -1181,12 +1190,7 @@ def cli(database, user, host, port, socket, password, dbname, | |
socket=socket, | ||
local_infile=local_infile, | ||
ssl=ssl, | ||
ssh_user=ssh_user, | ||
ssh_host=ssh_host, | ||
ssh_port=ssh_port, | ||
ssh_password=ssh_password, | ||
ssh_key_filename=ssh_key_filename, | ||
init_command=init_command | ||
init_command=init_command, | ||
) | ||
|
||
mycli.logger.debug('Launch Params: \n' | ||
|
@@ -1305,26 +1309,5 @@ def edit_and_execute(event): | |
buff.open_in_editor(validate_and_handle=False) | ||
|
||
|
||
def read_ssh_config(ssh_config_path): | ||
ssh_config = paramiko.config.SSHConfig() | ||
try: | ||
with open(ssh_config_path) as f: | ||
ssh_config.parse(f) | ||
# Paramiko prior to version 2.7 raises Exception on parse errors. | ||
# In 2.7 it has become paramiko.ssh_exception.SSHException, | ||
# but let's catch everything for compatibility | ||
except Exception as err: | ||
click.secho( | ||
f'Could not parse SSH configuration file {ssh_config_path}:\n{err} ', | ||
err=True, fg='red' | ||
) | ||
sys.exit(1) | ||
except FileNotFoundError as e: | ||
click.secho(str(e), err=True, fg='red') | ||
sys.exit(1) | ||
else: | ||
return ssh_config | ||
|
||
|
||
if __name__ == "__main__": | ||
cli() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .client import get_config_hosts, create_ssh_client, SSHException, read_config_file |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
"""A very thin wrapper around paramiko, mostly to keep all SSH-related | ||
functionality in one place.""" | ||
from io import open | ||
|
||
try: | ||
import paramiko | ||
except ImportError: | ||
from mycli.packages.paramiko_stub import paramiko | ||
|
||
|
||
class SSHException(Exception): | ||
pass | ||
|
||
|
||
def get_config_hosts(config_path): | ||
config = read_config_file(config_path) | ||
return { | ||
host: config.lookup(host).get("hostname") for host in config.get_hostnames() | ||
} | ||
|
||
|
||
def create_ssh_client(ssh_host, ssh_port, ssh_user, ssh_password=None, ssh_key_filename=None) -> paramiko.SSHClient: | ||
client = paramiko.SSHClient() | ||
client.load_system_host_keys() | ||
client.set_missing_host_key_policy(paramiko.WarningPolicy()) | ||
client.connect( | ||
ssh_host, ssh_port, ssh_user, password=ssh_password, key_filename=ssh_key_filename | ||
) | ||
return client | ||
|
||
|
||
def read_config_file(config_path) -> paramiko.SSHConfig: | ||
ssh_config = paramiko.config.SSHConfig() | ||
try: | ||
with open(config_path) as f: | ||
ssh_config.parse(f) | ||
# Paramiko prior to version 2.7 raises Exception on parse errors. | ||
# In 2.7 it has become paramiko.ssh_exception.SSHException, | ||
# but let's catch everything for compatibility | ||
except Exception as err: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
raise SSHException( | ||
f"Could not parse SSH configuration file {config_path}:\n{err} ", | ||
) | ||
except FileNotFoundError as e: | ||
raise SSHException(str(e)) | ||
return ssh_config |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,10 +6,7 @@ | |
from pymysql.converters import (convert_datetime, | ||
convert_timedelta, convert_date, conversions, | ||
decoders) | ||
try: | ||
import paramiko | ||
except ImportError: | ||
from mycli.packages.paramiko_stub import paramiko | ||
|
||
|
||
_logger = logging.getLogger(__name__) | ||
|
||
|
@@ -18,6 +15,7 @@ | |
FIELD_TYPE.NULL: type(None) | ||
}) | ||
|
||
|
||
class SQLExecute(object): | ||
|
||
databases_query = '''SHOW DATABASES''' | ||
|
@@ -41,8 +39,8 @@ class SQLExecute(object): | |
order by table_name,ordinal_position''' | ||
|
||
def __init__(self, database, user, password, host, port, socket, charset, | ||
local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password, | ||
ssh_key_filename, init_command=None): | ||
local_infile, ssl, init_command=None, | ||
ssh_client=None): | ||
self.dbname = database | ||
self.user = user | ||
self.password = password | ||
|
@@ -54,18 +52,14 @@ def __init__(self, database, user, password, host, port, socket, charset, | |
self.ssl = ssl | ||
self._server_type = None | ||
self.connection_id = None | ||
self.ssh_user = ssh_user | ||
self.ssh_host = ssh_host | ||
self.ssh_port = ssh_port | ||
self.ssh_password = ssh_password | ||
self.ssh_key_filename = ssh_key_filename | ||
self.init_command = init_command | ||
self.ssh_client = ssh_client | ||
|
||
self.connect() | ||
|
||
def connect(self, database=None, user=None, password=None, host=None, | ||
port=None, socket=None, charset=None, local_infile=None, | ||
ssl=None, ssh_host=None, ssh_port=None, ssh_user=None, | ||
ssh_password=None, ssh_key_filename=None, init_command=None): | ||
port=None, socket=None, charset=None, local_infile=None, ssl=None, | ||
ssh_client=None, init_command=None): | ||
db = (database or self.dbname) | ||
user = (user or self.user) | ||
password = (password or self.password) | ||
|
@@ -75,11 +69,6 @@ def connect(self, database=None, user=None, password=None, host=None, | |
charset = (charset or self.charset) | ||
local_infile = (local_infile or self.local_infile) | ||
ssl = (ssl or self.ssl) | ||
ssh_user = (ssh_user or self.ssh_user) | ||
ssh_host = (ssh_host or self.ssh_host) | ||
ssh_port = (ssh_port or self.ssh_port) | ||
ssh_password = (ssh_password or self.ssh_password) | ||
ssh_key_filename = (ssh_key_filename or self.ssh_key_filename) | ||
init_command = (init_command or self.init_command) | ||
_logger.debug( | ||
'Connection DB Params: \n' | ||
|
@@ -90,15 +79,9 @@ def connect(self, database=None, user=None, password=None, host=None, | |
'\tsocket: %r' | ||
'\tcharset: %r' | ||
'\tlocal_infile: %r' | ||
'\tssl: %r' | ||
'\tssh_user: %r' | ||
'\tssh_host: %r' | ||
'\tssh_port: %r' | ||
'\tssh_password: %r' | ||
'\tssh_key_filename: %r' | ||
'\tssl: %r', | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could still log these values, if we wanted to, right? |
||
'\tinit_command: %r', | ||
db, user, host, port, socket, charset, local_infile, ssl, | ||
ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename, | ||
init_command | ||
) | ||
conv = conversions.copy() | ||
|
@@ -111,9 +94,6 @@ def connect(self, database=None, user=None, password=None, host=None, | |
|
||
defer_connect = False | ||
|
||
if ssh_host: | ||
defer_connect = True | ||
|
||
client_flag = pymysql.constants.CLIENT.INTERACTIVE | ||
if init_command and len(list(special.split_queries(init_command))) > 1: | ||
client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS | ||
|
@@ -123,18 +103,12 @@ def connect(self, database=None, user=None, password=None, host=None, | |
unix_socket=socket, use_unicode=True, charset=charset, | ||
autocommit=True, client_flag=client_flag, | ||
local_infile=local_infile, conv=conv, ssl=ssl, program_name="mycli", | ||
defer_connect=defer_connect, init_command=init_command | ||
defer_connect=self.ssh_client is not None, | ||
init_command=init_command | ||
) | ||
|
||
if ssh_host: | ||
client = paramiko.SSHClient() | ||
client.load_system_host_keys() | ||
client.set_missing_host_key_policy(paramiko.WarningPolicy()) | ||
client.connect( | ||
ssh_host, ssh_port, ssh_user, ssh_password, | ||
key_filename=ssh_key_filename | ||
) | ||
chan = client.get_transport().open_channel( | ||
if self.ssh_client: | ||
chan = self.ssh_client.get_transport().open_channel( | ||
'direct-tcpip', | ||
(host, port), | ||
('0.0.0.0', 0), | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On rebase this should become a little different, like