diff --git a/mycli/config.py b/mycli/config.py index 77475099..e0f2d1fc 100644 --- a/mycli/config.py +++ b/mycli/config.py @@ -1,10 +1,13 @@ +import io import shutil +from copy import copy from io import BytesIO, TextIOWrapper import logging import os from os.path import exists import struct import sys +from typing import Union from configobj import ConfigObj, ConfigObjError from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes @@ -58,13 +61,50 @@ def read_config_file(f, list_values=True): return config +def get_included_configs(config_file: Union[str, io.TextIOWrapper]) -> list: + """Get a list of configuration files that are included into config_path + with !includedir directive. + + "Normal" configs should be passed as file paths. The only exception + is .mylogin which is decoded into a stream. However, it never + contains include directives and so will be ignored by this + function. + + """ + if not isinstance(config_file, str) or not os.path.isfile(config_file): + return [] + included_configs = [] + + try: + with open(config_file) as f: + include_directives = filter( + lambda s: s.startswith('!includedir'), + f + ) + dirs = map(lambda s: s.strip().split()[-1], include_directives) + dirs = filter(os.path.isdir, dirs) + for dir in dirs: + for filename in os.listdir(dir): + if filename.endswith('.cnf'): + included_configs.append(os.path.join(dir, filename)) + except (PermissionError, UnicodeDecodeError): + pass + return included_configs + + def read_config_files(files, list_values=True): """Read and merge a list of config files.""" config = ConfigObj(list_values=list_values) - - for _file in files: + _files = copy(files) + while _files: + _file = _files.pop(0) _config = read_config_file(_file, list_values=list_values) + + # expand includes only if we were able to parse config + # (otherwise we'll just encounter the same errors again) + if config is not None: + _files = get_included_configs(_file) + _files if bool(_config) is True: config.merge(_config) config.filename = _config.filename diff --git a/mycli/main.py b/mycli/main.py index 34b5c4f6..d298f202 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -6,6 +6,7 @@ import re import fileinput from collections import namedtuple +from pwd import getpwuid from time import time from datetime import datetime from random import choice @@ -48,7 +49,7 @@ from .lexer import MyCliLexer from .__init__ import __version__ from .compat import WIN -from .packages.filepaths import dir_path_exists +from .packages.filepaths import dir_path_exists, guess_socket_location import itertools @@ -317,7 +318,7 @@ def read_my_cnf_files(self, files, keys): """ cnf = read_config_files(files, list_values=False) - sections = ['client'] + sections = ['client', 'mysqld'] if self.login_path and self.login_path != 'client': sections.append(self.login_path) @@ -382,10 +383,11 @@ def connect(self, database='', user='', passwd='', host='', port='', # Fall back to config values only if user did not specify a value. database = database or cnf['database'] - if port or host: + # Socket interface not supported for SSH connections + if port or host or ssh_host or ssh_port: socket = '' else: - socket = socket or cnf['socket'] + socket = socket or cnf['socket'] or guess_socket_location() user = user or cnf['user'] or os.getenv('USER') host = host or cnf['host'] port = port or cnf['port'] @@ -430,11 +432,11 @@ def _connect(): raise e try: - if (socket is host is port is None) and not WIN: - # Try a sensible default socket first (simplifies auth) - # If we get a connection error, try tcp/ip localhost + if not WIN and socket: + socket_owner = getpwuid(os.stat(socket).st_uid).pw_name + self.echo( + f"Connecting to socket {socket}, owned by user {socket_owner}") try: - socket = '/var/run/mysqld/mysqld.sock' _connect() except OperationalError as e: # These are "Can't open socket" and 2x "Can't connect" diff --git a/mycli/packages/filepaths.py b/mycli/packages/filepaths.py index ac58851d..79fe26dc 100644 --- a/mycli/packages/filepaths.py +++ b/mycli/packages/filepaths.py @@ -1,10 +1,20 @@ import os +import platform + + +if os.name == "posix": + if platform.system() == "Darwin": + DEFAULT_SOCKET_DIRS = ("/tmp",) + else: + DEFAULT_SOCKET_DIRS = ("/var/run", "/var/lib") +else: + DEFAULT_SOCKET_DIRS = () def list_path(root_dir): """List directory if exists. - :param dir: str + :param root_dir: str :return: list """ @@ -81,3 +91,16 @@ def dir_path_exists(path): """ return os.path.exists(os.path.dirname(path)) + + +def guess_socket_location(): + """Try to guess the location of the default mysql socket file.""" + socket_dirs = filter(os.path.exists, DEFAULT_SOCKET_DIRS) + for directory in socket_dirs: + for r, dirs, files in os.walk(directory, topdown=True): + for filename in files: + name, ext = os.path.splitext(filename) + if name.startswith("mysql") and ext in ('.socket', '.sock'): + return os.path.join(r, filename) + dirs[:] = [d for d in dirs if d.startswith("mysql")] + return None diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index 82b9f000..035d98d1 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -191,7 +191,6 @@ def run(self, statement): if not cur.nextset() or (not cur.rowcount and cur.description is None): break - def get_result(self, cursor): """Get the current result's data from the cursor.""" title = headers = None