Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ vagrant/.vagrant
.vscode/
*.iml
.pytest_cache/
*.so
83 changes: 60 additions & 23 deletions aredis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,70 @@
from aredis.client import (StrictRedis, StrictRedisCluster)
from aredis.connection import (
Connection,
UnixDomainSocketConnection,
ClusterConnection
)
from aredis.client import StrictRedis, StrictRedisCluster
from aredis.connection import Connection, UnixDomainSocketConnection, ClusterConnection
from aredis.pool import ConnectionPool, ClusterConnectionPool
from aredis.exceptions import (
AuthenticationError, BusyLoadingError, ConnectionError,
DataError, InvalidResponse, PubSubError, ReadOnlyError,
RedisError, ResponseError, TimeoutError, WatchError,
CompressError, ClusterDownException, ClusterCrossSlotError,
CacheError, ClusterDownError, ClusterError, RedisClusterException,
RedisClusterError, ExecAbortError, LockError, NoScriptError
AuthenticationFailureError,
AuthenticationRequiredError,
NoPermissionError,
BusyLoadingError,
ConnectionError,
DataError,
InvalidResponse,
PubSubError,
ReadOnlyError,
RedisError,
ResponseError,
TimeoutError,
WatchError,
CompressError,
ClusterDownException,
ClusterCrossSlotError,
CacheError,
ClusterDownError,
ClusterError,
RedisClusterException,
RedisClusterError,
ExecAbortError,
LockError,
NoScriptError,
)


__version__ = '1.1.8'
__version__ = "1.1.8"


VERSION = tuple(map(int, __version__.split('.')))
VERSION = tuple(map(int, __version__.split(".")))


__all__ = [
'StrictRedis', 'StrictRedisCluster',
'Connection', 'UnixDomainSocketConnection', 'ClusterConnection',
'ConnectionPool', 'ClusterConnectionPool',
'AuthenticationError', 'BusyLoadingError', 'ConnectionError', 'DataError',
'InvalidResponse', 'PubSubError', 'ReadOnlyError', 'RedisError',
'ResponseError', 'TimeoutError', 'WatchError',
'CompressError', 'ClusterDownException', 'ClusterCrossSlotError',
'CacheError', 'ClusterDownError', 'ClusterError', 'RedisClusterException',
'RedisClusterError', 'ExecAbortError', 'LockError', 'NoScriptError'
"StrictRedis",
"StrictRedisCluster",
"Connection",
"UnixDomainSocketConnection",
"ClusterConnection",
"ConnectionPool",
"ClusterConnectionPool",
"AuthenticationFailureError",
"AuthenticationRequiredError",
"NoPermissionError",
"BusyLoadingError",
"ConnectionError",
"DataError",
"InvalidResponse",
"PubSubError",
"ReadOnlyError",
"RedisError",
"ResponseError",
"TimeoutError",
"WatchError",
"CompressError",
"ClusterDownException",
"ClusterCrossSlotError",
"CacheError",
"ClusterDownError",
"ClusterError",
"RedisClusterException",
"RedisClusterError",
"ExecAbortError",
"LockError",
"NoScriptError",
]
5 changes: 4 additions & 1 deletion aredis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,18 +92,20 @@ def from_url(cls, url, db=None, **kwargs):
return cls(connection_pool=connection_pool)

def __init__(self, host='localhost', port=6379,
db=0, password=None, stream_timeout=None,
db=0, username=None, password=None, stream_timeout=None,
connect_timeout=None, connection_pool=None,
unix_socket_path=None, encoding='utf-8',
decode_responses=False, ssl=False, ssl_context=None,
ssl_keyfile=None, ssl_certfile=None,
ssl_cert_reqs=None, ssl_ca_certs=None,
max_connections=None, retry_on_timeout=False,
max_idle_time=0, idle_check_interval=1,
client_name=None,
loop=None, **kwargs):
if not connection_pool:
kwargs = {
'db': db,
'username': username,
'password': password,
'encoding': encoding,
'stream_timeout': stream_timeout,
Expand All @@ -113,6 +115,7 @@ def __init__(self, host='localhost', port=6379,
'decode_responses': decode_responses,
'max_idle_time': max_idle_time,
'idle_check_interval': idle_check_interval,
'client_name': client_name,
'loop': loop
}
# based on input, setup appropriate connection args
Expand Down
38 changes: 28 additions & 10 deletions aredis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from io import BytesIO

import aredis.compat
from aredis.exceptions import (ConnectionError, TimeoutError,
from aredis.exceptions import (AuthenticationFailureError, AuthenticationRequiredError,
NoPermissionError, ConnectionError, TimeoutError,
RedisError, ExecAbortError,
BusyLoadingError, NoScriptError,
ReadOnlyError, ResponseError,
Expand Down Expand Up @@ -154,6 +155,9 @@ class BaseParser:
'MOVED': MovedError,
'CLUSTERDOWN': ClusterDownError,
'CROSSSLOT': ClusterCrossSlotError,
'WRONGPASS': AuthenticationFailureError,
'NOAUTH': AuthenticationRequiredError,
'NOPERM': NoPermissionError,
}

def parse_error(self, response):
Expand Down Expand Up @@ -367,11 +371,12 @@ class BaseConnection:
def __init__(self, retry_on_timeout=False, stream_timeout=None,
parser_class=DefaultParser, reader_read_size=65535,
encoding='utf-8', decode_responses=False,
*, loop=None):
*, client_name=None, loop=None):
self._parser = parser_class(reader_read_size)
self._stream_timeout = stream_timeout
self._reader = None
self._writer = None
self.username = ''
self.password = ''
self.db = ''
self.pid = os.getpid()
Expand All @@ -381,6 +386,7 @@ def __init__(self, retry_on_timeout=False, stream_timeout=None,
self.encoding = encoding
self.decode_responses = decode_responses
self.loop = loop
self.client_name = client_name
# flag to show if a connection is waiting for response
self.awaiting_response = False
self.last_active_at = time.time()
Expand Down Expand Up @@ -433,17 +439,27 @@ async def _connect(self):
async def on_connect(self):
self._parser.on_connect(self)

# if a username and a password is specified, authenticate
if self.username and self.password:
await self.send_command('AUTH', self.username, self.password)
if nativestr(await self.read_response()) != 'OK':
raise ConnectionError('Failed to set username or password')
# if a password is specified, authenticate
if self.password:
elif self.password:
await self.send_command('AUTH', self.password)
if nativestr(await self.read_response()) != 'OK':
raise ConnectionError('Invalid Password')
raise ConnectionError('Failed to set password')

# if a database is specified, switch to it
if self.db:
await self.send_command('SELECT', self.db)
if nativestr(await self.read_response()) != 'OK':
raise ConnectionError('Invalid Database')

if self.client_name is not None:
await self.send_command('CLIENT SETNAME', self.client_name)
if nativestr(await self.read_response()) != 'OK':
raise ConnectionError('Failed to set client name: {}'.format(self.client_name))
self.last_active_at = time.time()

async def read_response(self):
Expand Down Expand Up @@ -569,17 +585,18 @@ def pack_commands(self, commands):
class Connection(BaseConnection):
description = 'Connection<host={host},port={port},db={db}>'

def __init__(self, host='127.0.0.1', port=6379, password=None,
def __init__(self, host='127.0.0.1', port=6379, username=None, password=None,
db=0, retry_on_timeout=False, stream_timeout=None, connect_timeout=None,
ssl_context=None, parser_class=DefaultParser, reader_read_size=65535,
encoding='utf-8', decode_responses=False, socket_keepalive=None,
socket_keepalive_options=None, *, loop=None):
socket_keepalive_options=None, *, client_name=None, loop=None):
super(Connection, self).__init__(retry_on_timeout, stream_timeout,
parser_class, reader_read_size,
encoding, decode_responses,
loop=loop)
client_name=client_name, loop=loop)
self.host = host
self.port = port
self.username = username
self.password = password
self.db = db
self.ssl_context = ssl_context
Expand Down Expand Up @@ -623,16 +640,17 @@ async def _connect(self):
class UnixDomainSocketConnection(BaseConnection):
description = "UnixDomainSocketConnection<path={path},db={db}>"

def __init__(self, path='', password=None,
def __init__(self, path='', username=None, password=None,
db=0, retry_on_timeout=False, stream_timeout=None, connect_timeout=None,
ssl_context=None, parser_class=DefaultParser, reader_read_size=65535,
encoding='utf-8', decode_responses=False, *, loop=None):
encoding='utf-8', decode_responses=False, *, client_name=None, loop=None):
super(UnixDomainSocketConnection, self).__init__(retry_on_timeout, stream_timeout,
parser_class, reader_read_size,
encoding, decode_responses,
loop=loop)
client_name=client_name, loop=loop)
self.path = path
self.db = db
self.username = username
self.password = password
self.ssl_context = ssl_context
self._connect_timeout = connect_timeout
Expand Down
10 changes: 9 additions & 1 deletion aredis/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@ class RedisError(Exception):
pass


class AuthenticationError(RedisError):
class AuthenticationFailureError(RedisError):
pass


class AuthenticationRequiredError(RedisError):
pass


class NoPermissionError(RedisError):
pass


Expand Down
4 changes: 4 additions & 0 deletions aredis/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,20 @@ def from_url(cls, url, db=None, decode_components=False, **kwargs):
url_options[name] = value[0]

if decode_components:
username = unquote(url.username) if url.username else None
password = unquote(url.password) if url.password else None
path = unquote(url.path) if url.path else None
hostname = unquote(url.hostname) if url.hostname else None
else:
username = url.username
password = url.password
path = url.path
hostname = url.hostname

# We only support redis:// and unix:// schemes.
if url.scheme == 'unix':
url_options.update({
'username': username,
'password': password,
'path': path,
'connection_class': UnixDomainSocketConnection,
Expand All @@ -123,6 +126,7 @@ def from_url(cls, url, db=None, decode_components=False, **kwargs):
url_options.update({
'host': hostname,
'port': int(url.port or 6379),
'username': username,
'password': password,
})

Expand Down
2 changes: 1 addition & 1 deletion tests/client/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def skip_python_vsersion_lt(min_version):

@pytest.fixture()
def r(event_loop):
return aredis.StrictRedis(loop=event_loop)
return aredis.StrictRedis(client_name='test', loop=event_loop)


class AsyncMock(Mock):
Expand Down
2 changes: 1 addition & 1 deletion tests/client/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async def test_client_list_after_client_setname(self, r):
@skip_if_server_version_lt('2.6.9')
@pytest.mark.asyncio(forbid_global_loop=True)
async def test_client_getname(self, r):
assert await r.client_getname() is None
assert await r.client_getname() == 'test'

@skip_if_server_version_lt('2.6.9')
@pytest.mark.asyncio(forbid_global_loop=True)
Expand Down