Skip to content

Commit

Permalink
Merge pull request #795 from dbcli/default-socket-location
Browse files Browse the repository at this point in the history
try more default socket paths
  • Loading branch information
amjith authored May 23, 2020
2 parents 43140ba + 977e409 commit 7a059a6
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 12 deletions.
44 changes: 42 additions & 2 deletions mycli/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import io
import shutil
from copy import copy
from io import BytesIO, TextIOWrapper
import logging
import os
from os.path import exists
import struct
import sys
from typing import Union

from configobj import ConfigObj, ConfigObjError
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
Expand Down Expand Up @@ -58,13 +61,50 @@ def read_config_file(f, list_values=True):
return config


def get_included_configs(config_file: Union[str, io.TextIOWrapper]) -> list:
"""Get a list of configuration files that are included into config_path
with !includedir directive.
"Normal" configs should be passed as file paths. The only exception
is .mylogin which is decoded into a stream. However, it never
contains include directives and so will be ignored by this
function.
"""
if not isinstance(config_file, str) or not os.path.isfile(config_file):
return []
included_configs = []

try:
with open(config_file) as f:
include_directives = filter(
lambda s: s.startswith('!includedir'),
f
)
dirs = map(lambda s: s.strip().split()[-1], include_directives)
dirs = filter(os.path.isdir, dirs)
for dir in dirs:
for filename in os.listdir(dir):
if filename.endswith('.cnf'):
included_configs.append(os.path.join(dir, filename))
except (PermissionError, UnicodeDecodeError):
pass
return included_configs


def read_config_files(files, list_values=True):
"""Read and merge a list of config files."""

config = ConfigObj(list_values=list_values)

for _file in files:
_files = copy(files)
while _files:
_file = _files.pop(0)
_config = read_config_file(_file, list_values=list_values)

# expand includes only if we were able to parse config
# (otherwise we'll just encounter the same errors again)
if config is not None:
_files = get_included_configs(_file) + _files
if bool(_config) is True:
config.merge(_config)
config.filename = _config.filename
Expand Down
18 changes: 10 additions & 8 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import re
import fileinput
from collections import namedtuple
from pwd import getpwuid
from time import time
from datetime import datetime
from random import choice
Expand Down Expand Up @@ -48,7 +49,7 @@
from .lexer import MyCliLexer
from .__init__ import __version__
from .compat import WIN
from .packages.filepaths import dir_path_exists
from .packages.filepaths import dir_path_exists, guess_socket_location

import itertools

Expand Down Expand Up @@ -317,7 +318,7 @@ def read_my_cnf_files(self, files, keys):
"""
cnf = read_config_files(files, list_values=False)

sections = ['client']
sections = ['client', 'mysqld']
if self.login_path and self.login_path != 'client':
sections.append(self.login_path)

Expand Down Expand Up @@ -382,10 +383,11 @@ def connect(self, database='', user='', passwd='', host='', port='',
# Fall back to config values only if user did not specify a value.

database = database or cnf['database']
if port or host:
# Socket interface not supported for SSH connections
if port or host or ssh_host or ssh_port:
socket = ''
else:
socket = socket or cnf['socket']
socket = socket or cnf['socket'] or guess_socket_location()
user = user or cnf['user'] or os.getenv('USER')
host = host or cnf['host']
port = port or cnf['port']
Expand Down Expand Up @@ -430,11 +432,11 @@ def _connect():
raise e

try:
if (socket is host is port is None) and not WIN:
# Try a sensible default socket first (simplifies auth)
# If we get a connection error, try tcp/ip localhost
if not WIN and socket:
socket_owner = getpwuid(os.stat(socket).st_uid).pw_name
self.echo(
f"Connecting to socket {socket}, owned by user {socket_owner}")
try:
socket = '/var/run/mysqld/mysqld.sock'
_connect()
except OperationalError as e:
# These are "Can't open socket" and 2x "Can't connect"
Expand Down
25 changes: 24 additions & 1 deletion mycli/packages/filepaths.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import os
import platform


if os.name == "posix":
if platform.system() == "Darwin":
DEFAULT_SOCKET_DIRS = ("/tmp",)
else:
DEFAULT_SOCKET_DIRS = ("/var/run", "/var/lib")
else:
DEFAULT_SOCKET_DIRS = ()


def list_path(root_dir):
"""List directory if exists.
:param dir: str
:param root_dir: str
:return: list
"""
Expand Down Expand Up @@ -81,3 +91,16 @@ def dir_path_exists(path):
"""
return os.path.exists(os.path.dirname(path))


def guess_socket_location():
"""Try to guess the location of the default mysql socket file."""
socket_dirs = filter(os.path.exists, DEFAULT_SOCKET_DIRS)
for directory in socket_dirs:
for r, dirs, files in os.walk(directory, topdown=True):
for filename in files:
name, ext = os.path.splitext(filename)
if name.startswith("mysql") and ext in ('.socket', '.sock'):
return os.path.join(r, filename)
dirs[:] = [d for d in dirs if d.startswith("mysql")]
return None
1 change: 0 additions & 1 deletion mycli/sqlexecute.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@ def run(self, statement):
if not cur.nextset() or (not cur.rowcount and cur.description is None):
break


def get_result(self, cursor):
"""Get the current result's data from the cursor."""
title = headers = None
Expand Down

0 comments on commit 7a059a6

Please sign in to comment.