Skip to content

Commit caae469

Browse files
committed
restructure of type hint system of pools for better mypy compatibility
1 parent 537a7b3 commit caae469

File tree

4 files changed

+33
-30
lines changed

4 files changed

+33
-30
lines changed

django_valkey/async_cache/pool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from django_valkey.typings import AsyncDefaultParserT
1313

1414

15-
class AsyncConnectionFactory(BaseConnectionFactory[AValkey, ConnectionPool]):
15+
class AsyncConnectionFactory(BaseConnectionFactory[AValkey]):
1616
path_pool_cls = "valkey.asyncio.connection.ConnectionPool"
1717
path_base_cls = "valkey.asyncio.client.Valkey"
1818

django_valkey/base_pool.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,24 @@
33
from django.core.exceptions import ImproperlyConfigured
44
from django.utils.module_loading import import_string
55

6-
Pool = TypeVar("Pool")
76
Base = TypeVar("Base")
87

98

10-
class BaseConnectionFactory(Generic[Base, Pool]):
9+
class BaseConnectionFactory(Generic[Base]):
1110
# Store connection pool by cache backend options.
1211
#
1312
# _pools is a process-global, as otherwise _pools is cleared every time
1413
# ConnectionFactory is instantiated, as Django creates new cache client
1514
# (DefaultClient) instance for every request.
1615

17-
_pools: dict[str, Pool | Any] = {}
16+
_pools: dict[str, type] = {} # dict[str, Pool]
1817

1918
def __init__(self, options: dict) -> None:
20-
pool_cls_path: str = options.get("CONNECTION_POOL_CLASS", self.path_pool_cls)
21-
self.pool_cls: type[Pool] = import_string(pool_cls_path)
19+
pool_cls_path: str = options.get("CONNECTION_POOL_CLASS", self.path_pool_cls) # type: ignore[attr-defined]
20+
self.pool_cls = import_string(pool_cls_path) # type[Pool]
2221
self.pool_cls_kwargs = options.get("CONNECTION_POOL_KWARGS", {})
2322

24-
base_client_cls_path: str = options.get("BASE_CLIENT_CLASS", self.path_base_cls)
23+
base_client_cls_path: str = options.get("BASE_CLIENT_CLASS", self.path_base_cls) # type: ignore[attr-defined]
2524
self.base_client_cls: type[Base] = import_string(base_client_cls_path)
2625
self.base_client_cls_kwargs = options.get("BASE_CLIENT_KWARGS", {})
2726

@@ -42,7 +41,7 @@ def make_connection_params(self, url: str | None) -> dict:
4241
# TODO: do we need to check for existence?
4342
if socket_timeout:
4443
if not isinstance(socket_timeout, (int, float)):
45-
error_message = "Socket timeout should be float or integer"
44+
error_message = "Socket timeout should be float or integer" # type: ignore[unreachable]
4645
raise ImproperlyConfigured(error_message)
4746
kwargs["socket_timeout"] = socket_timeout
4847

@@ -51,7 +50,7 @@ def make_connection_params(self, url: str | None) -> dict:
5150
)
5251
if socket_connect_timeout:
5352
if not isinstance(socket_connect_timeout, (int, float)):
54-
error_message = "Socket connect timeout should be float or integer"
53+
error_message = "Socket connect timeout should be float or integer" # type: ignore[unreachable]
5554
raise ImproperlyConfigured(error_message)
5655
kwargs["socket_connect_timeout"] = socket_connect_timeout
5756

@@ -61,31 +60,34 @@ def make_connection_params(self, url: str | None) -> dict:
6160

6261
return kwargs
6362

64-
def get_connection_pool(self, params: dict) -> Pool:
63+
def get_connection_pool(self, params: dict):
6564
"""
6665
Given a connection parameters, return a new
6766
connection pool for them.
6867
6968
Overwrite this method if you want a custom
7069
behavior on creating connection pool.
70+
71+
:returns: the connection pool
7172
"""
72-
cp_params = params
73-
cp_params.update(self.pool_cls_kwargs)
74-
pool = self.pool_cls.from_url(**cp_params)
73+
params.update(self.pool_cls_kwargs)
74+
pool = self.pool_cls.from_url(**params)
7575

7676
if pool.connection_kwargs.get("password", None) is None:
7777
pool.connection_kwargs["password"] = params.get("password")
7878
pool.reset()
7979

8080
return pool
8181

82-
def get_or_create_connection_pool(self, params: dict) -> Pool:
82+
def get_or_create_connection_pool(self, params: dict):
8383
"""
8484
Given a connection parameters and return a new
8585
or cached connection pool for them.
8686
8787
Reimplement this method if you want distinct
8888
connection pool instance caching behavior.
89+
90+
:returns: the connection pool
8991
"""
9092
key: str = params["url"]
9193
if key not in self._pools:

django_valkey/cluster_cache/pool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from django_valkey.base_pool import BaseConnectionFactory
88

99

10-
class ClusterConnectionFactory(BaseConnectionFactory[ValkeyCluster, ConnectionPool]):
10+
class ClusterConnectionFactory(BaseConnectionFactory[ValkeyCluster]):
1111
path_pool_cls = "valkey.connection.ConnectionPool"
1212
path_base_cls = "valkey.cluster.ValkeyCluster"
1313

django_valkey/pool.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,27 @@
1+
from typing import TYPE_CHECKING
12
from urllib.parse import parse_qs, urlparse
23

34
from django.conf import settings
45
from django.core.exceptions import ImproperlyConfigured
56
from django.utils.module_loading import import_string
67
from valkey import Valkey
7-
from valkey.connection import ConnectionPool, DefaultParser
8+
from valkey.connection import DefaultParser
89
from valkey.sentinel import Sentinel
910
from valkey._parsers.url_parser import to_bool
1011

1112
from django_valkey.base_pool import BaseConnectionFactory
1213
from django_valkey.typings import DefaultParserT
1314

14-
from django_valkey.async_cache.pool import (
15-
AsyncConnectionFactory,
16-
AsyncSentinelConnectionFactory,
17-
)
15+
if TYPE_CHECKING:
16+
from django_valkey.async_cache.pool import (
17+
AsyncConnectionFactory,
18+
AsyncSentinelConnectionFactory,
19+
)
20+
from valkey.connection import ConnectionPool
21+
from valkey.sentinel import SentinelConnectionPool
1822

1923

20-
class ConnectionFactory(BaseConnectionFactory[Valkey, ConnectionPool]):
24+
class ConnectionFactory(BaseConnectionFactory[Valkey]):
2125
path_pool_cls = "valkey.connection.ConnectionPool"
2226
path_base_cls = "valkey.client.Valkey"
2327

@@ -40,11 +44,13 @@ def connect(self, url: str) -> Valkey:
4044
return self.get_connection(params)
4145

4246
def get_connection(self, params: dict) -> Valkey:
43-
pool = self.get_or_create_connection_pool(params)
47+
pool: ConnectionPool = self.get_or_create_connection_pool(params)
4448
return self.base_client_cls(connection_pool=pool, **self.base_client_cls_kwargs)
4549

4650

4751
class SentinelConnectionFactory(ConnectionFactory):
52+
path_pool_cls = "valkey.sentinel.SentinelConnectionPool"
53+
4854
def __init__(self, options: dict):
4955
# allow overriding the default SentinelConnectionPool class
5056
options.setdefault(
@@ -68,7 +74,7 @@ def __init__(self, options: dict):
6874
**connection_kwargs,
6975
)
7076

71-
def get_connection_pool(self, params: dict) -> ConnectionPool:
77+
def get_connection_pool(self, params: dict) -> "SentinelConnectionPool":
7278
"""
7379
Given a connection parameters, return a new sentinel connection pool
7480
for them.
@@ -79,7 +85,7 @@ def get_connection_pool(self, params: dict) -> ConnectionPool:
7985
# SentinelConnectionPool constructor since will be called by from_url
8086
cp_params = params
8187
cp_params.update(service_name=url.hostname, sentinel_manager=self._sentinel)
82-
pool = super().get_connection_pool(cp_params)
88+
pool: SentinelConnectionPool = super().get_connection_pool(cp_params)
8389

8490
# convert "is_master" to a boolean if set on the URL, otherwise if not
8591
# provided it defaults to True.
@@ -92,12 +98,7 @@ def get_connection_pool(self, params: dict) -> ConnectionPool:
9298

9399
def get_connection_factory(
94100
path: str | None = None, options: dict | None = None
95-
) -> (
96-
ConnectionFactory
97-
| SentinelConnectionFactory
98-
| AsyncConnectionFactory
99-
| AsyncSentinelConnectionFactory
100-
):
101+
) -> "ConnectionFactory | SentinelConnectionFactory | AsyncConnectionFactory | AsyncSentinelConnectionFactory":
101102

102103
path = getattr(settings, "DJANGO_VALKEY_CONNECTION_FACTORY", path)
103104
if options:

0 commit comments

Comments
 (0)