Skip to content

Commit ff45a8d

Browse files
committed
Work around deficiencies in redis-py typing
redis-py has problematical type annotations - async and sync return types are mixed together, and there is no representation of the fact that the client *might* decode the responses from bytes to str, or not, depending on how it was constructed. See redis/redis-py#3619 Work around this by adding protocols that represent the types of the API as we use it, and cast to the protocols as needed.
1 parent 5d9d564 commit ff45a8d

File tree

9 files changed

+192
-27
lines changed

9 files changed

+192
-27
lines changed

flatpak_indexer/bodhi_query.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import logging
55
import re
66

7+
from flatpak_indexer.redis_utils import TypedPipeline
8+
79
from .http_utils import HttpConfig
810
from .models import BodhiUpdateModel
911
from .nvr import NVR
@@ -198,7 +200,9 @@ def _query_updates(
198200
_run_query(requests_session, content_type, url, params, save_entities, results)
199201

200202

201-
def _refresh_updates(session: Session, content_type, entities, pipe, rows_per_page=None):
203+
def _refresh_updates(
204+
session: Session, content_type, entities, pipe: TypedPipeline, *, rows_per_page: int
205+
):
202206
pipe.watch("updates-by-entity:" + content_type)
203207

204208
assert isinstance(session.config, HttpConfig)
@@ -263,7 +267,9 @@ def refresh_updates(session: Session, content_type, entities, rows_per_page=10):
263267
)
264268

265269

266-
def _refresh_all_updates(session: Session, content_type, pipe, rows_per_page=10):
270+
def _refresh_all_updates(
271+
session: Session, content_type, pipe: TypedPipeline, *, rows_per_page: int
272+
):
267273
pipe.watch("updates-by-entity:" + content_type)
268274

269275
assert isinstance(session.config, HttpConfig)

flatpak_indexer/datasource/fedora/updater.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from collections import defaultdict
22
from typing import DefaultDict, Dict, List, NamedTuple, Optional, Set
33

4-
import redis
5-
64
from ...bodhi_query import (
75
list_updates,
86
refresh_all_updates,
@@ -62,7 +60,6 @@ def _fix_pull_spec(image: ImageModel, registry_url: str, repo_name: str):
6260

6361

6462
class FedoraUpdater(Updater):
65-
redis_client: "redis.Redis[bytes]"
6663
change_monitor: Optional[FedoraMonitor]
6764

6865
def __init__(self, config: Config):

flatpak_indexer/delta_generator.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import defaultdict
22
from datetime import datetime, timedelta, timezone
3-
from typing import DefaultDict, Dict, List, Optional, Set, Tuple, cast
3+
from typing import DefaultDict, Dict, List, Optional, Set, Tuple
44
import json
55
import logging
66
import time
@@ -209,8 +209,7 @@ def do_work(pubsub):
209209
next_expire = now + self.progress_timeout_seconds
210210
with self.redis_client.pipeline() as pipe:
211211
pipe.watch("tardiff:progress")
212-
pre = cast(redis.Redis, pipe)
213-
stale = pre.zrangebyscore(
212+
stale = pipe.zrangebyscore(
214213
"tardiff:progress", 0, now - self.progress_timeout_seconds
215214
)
216215
if len(stale) > 0:
@@ -227,7 +226,7 @@ def do_work(pubsub):
227226
# progress was modified, immediately try again
228227
return True
229228
else:
230-
oldest: List[Tuple[bytes, float]] = pre.zrange(
229+
oldest: List[Tuple[bytes, float]] = pipe.zrange(
231230
"tardiff:progress", 0, 0, withscores=True
232231
)
233232
if len(oldest) > 0:

flatpak_indexer/differ.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from datetime import datetime
2-
from typing import cast
32
import logging
43
import os
54
import subprocess
@@ -42,12 +41,11 @@ def _wait_for_task(self, pubsub):
4241
def _get_task(self):
4342
with self.redis_client.pipeline() as pipe:
4443
pipe.watch("tardiff:pending")
45-
pre = cast("redis.Redis[bytes]", pipe)
46-
task_raw = pre.srandmember("tardiff:pending")
44+
task_raw = pipe.srandmember("tardiff:pending")
4745
if task_raw is None:
4846
return None
4947

50-
task = cast(bytes, task_raw).decode("utf-8")
48+
task = task_raw.decode("utf-8")
5149

5250
pipe.multi()
5351
pipe.srem("tardiff:pending", task)

flatpak_indexer/fedora_monitor.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import partial
2-
from typing import Optional, Set, Tuple, cast
2+
from typing import Optional, Set, Tuple
33
import json
44
import logging
55
import os
@@ -12,9 +12,7 @@
1212
import pika.credentials
1313
import pika.exceptions
1414

15-
import redis
16-
17-
from .redis_utils import RedisConfig, get_redis_client
15+
from .redis_utils import RedisConfig, TypedPipeline, get_redis_client
1816

1917
logger = logging.getLogger(__name__)
2018

@@ -180,17 +178,14 @@ def _maybe_reraise_failure(self, msg):
180178
if self.failure:
181179
raise RuntimeError(msg) from self.failure
182180

183-
def _do_add_to_log(
184-
self, new_queue_name, update_id, distgit_path, pipe: "redis.client.Pipeline[bytes]"
185-
):
181+
def _do_add_to_log(self, new_queue_name, update_id, distgit_path, pipe: TypedPipeline):
186182
pipe.watch(KEY_SERIAL)
187183
if self.watch_bodhi_updates and (new_queue_name or update_id):
188184
pipe.watch(KEY_UPDATE_CHANGELOG)
189185
if self.watch_distgit_changes and (new_queue_name or distgit_path):
190186
pipe.watch(KEY_DISTGIT_CHANGELOG)
191187

192-
pre = cast("redis.Redis[bytes]", pipe)
193-
serial = 1 + int(pre.get("updatequeue:serial") or 0)
188+
serial = 1 + int(pipe.get("updatequeue:serial") or 0)
194189

195190
pipe.multi()
196191
pipe.set(KEY_SERIAL, serial)

flatpak_indexer/redis_utils.py

Lines changed: 167 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,22 @@
1-
from typing import Optional
1+
from typing import (
2+
Callable,
3+
ContextManager,
4+
Iterable,
5+
List,
6+
Literal,
7+
Mapping,
8+
Optional,
9+
Protocol,
10+
TypeVar,
11+
overload,
12+
)
213
from urllib.parse import quote, urlparse, urlunparse
314
import logging
415
import time
516

17+
from redis.typing import EncodableT, ExpiryT, FieldT, KeysT, KeyT, ZScoreBoundT
18+
import redis.client
19+
620
import redis
721

822
from .base_config import BaseConfig
@@ -15,7 +29,158 @@ class RedisConfig(BaseConfig):
1529
redis_password: Optional[str] = None
1630

1731

18-
def get_redis_client(config: RedisConfig) -> "redis.Redis[bytes]":
32+
ScoreT = TypeVar("ScoreT")
33+
34+
35+
class TypedRedis(Protocol):
36+
"""
37+
A typed subset of redis.Redis methods used by flatpak-indexer.
38+
39+
The redis-py type annotations are a mess: async and sync return types
40+
are mixed together, and there is no representation of the fact that
41+
the client *might* decode the responses from bytes to str, or not,
42+
depending on how it was constructed.
43+
44+
This Protocol represents the methods we actually use, with more precise
45+
types - we simply cast the redis.Redis instances we get to this type.
46+
"""
47+
48+
def delete(self, *names: KeyT) -> int: ...
49+
50+
def execute(self) -> List[object]: ...
51+
52+
def exists(self, name: KeyT) -> bool: ...
53+
54+
def hmget(self, name: KeyT, keys: KeysT, *args: KeyT) -> list[bytes | None]: ...
55+
56+
def hget(self, name: KeyT, key: KeyT) -> bytes | None: ...
57+
58+
def get(self, name: KeyT) -> Optional[bytes]: ...
59+
60+
@overload
61+
def hset(self, name: KeyT, key: KeyT, value: EncodableT) -> int: ...
62+
63+
@overload
64+
def hset(self, name: KeyT, mapping: Mapping[str, EncodableT]) -> int: ...
65+
66+
def mget(self, keys: KeysT, *args: KeyT) -> list[bytes | None]: ...
67+
68+
def multi(self): ...
69+
70+
def pipeline(self, *args, **kwargs) -> "TypedPipeline": ...
71+
72+
def publish(self, channel: KeyT, message: EncodableT) -> int: ...
73+
74+
def sadd(self, name: KeyT, *values: FieldT) -> int: ...
75+
76+
def scan_iter(self, match: Optional[EncodableT] = None) -> Iterable[str]: ...
77+
78+
def scard(self, name: KeyT) -> int: ...
79+
80+
def set(self, name: KeyT, value: EncodableT) -> bool: ...
81+
82+
def setex(self, name: KeyT, time: ExpiryT, value: EncodableT) -> bool: ...
83+
84+
@overload
85+
def srandmember(self, name: KeyT, number: None = None) -> Optional[bytes]: ...
86+
87+
@overload
88+
def srandmember(self, name: KeyT, number: int) -> List[bytes]: ...
89+
90+
def srem(self, name: KeyT, *values: FieldT) -> int: ...
91+
92+
def transaction(self, func: Callable[["TypedPipeline"], None]): ...
93+
94+
def watch(self, keys: KeysT): ...
95+
96+
def zadd(self, name: KeyT, mapping: dict[EncodableT, float], xx: bool = False) -> int: ...
97+
98+
def zcard(self, name: KeyT) -> int: ...
99+
100+
@overload
101+
def zrange(
102+
self,
103+
name: KeyT,
104+
start: int,
105+
end: int,
106+
*,
107+
desc: bool = False,
108+
withscores: Literal[False] = False,
109+
score_cast_func: type | Callable = float,
110+
byscore: bool = False,
111+
bylex: bool = False,
112+
offset: Optional[int] = None,
113+
num: Optional[int] = None,
114+
) -> list[bytes]: ...
115+
116+
@overload
117+
def zrange(
118+
self,
119+
name: KeyT,
120+
start: int,
121+
end: int,
122+
*,
123+
desc: bool = False,
124+
withscores: Literal[True],
125+
score_cast_func: type | Callable[[float], ScoreT] = float,
126+
byscore: bool = False,
127+
bylex: bool = False,
128+
offset: Optional[int] = None,
129+
num: Optional[int] = None,
130+
) -> list[tuple[bytes, ScoreT]]: ...
131+
132+
def zrangebylex(
133+
self,
134+
name: KeyT,
135+
min: EncodableT,
136+
max: EncodableT,
137+
start: Optional[int] = None,
138+
num: Optional[int] = None,
139+
) -> list[bytes]: ...
140+
141+
@overload
142+
def zrangebyscore(
143+
self,
144+
name: KeyT,
145+
min: ZScoreBoundT,
146+
max: ZScoreBoundT,
147+
*,
148+
start: Optional[int] = None,
149+
num: Optional[int] = None,
150+
withscores: Literal[False] = False,
151+
score_cast_func: type | Callable = float,
152+
) -> list[bytes]: ...
153+
154+
@overload
155+
def zrangebyscore(
156+
self,
157+
name: KeyT,
158+
min: ZScoreBoundT,
159+
max: ZScoreBoundT,
160+
*,
161+
start: Optional[int] = None,
162+
num: Optional[int] = None,
163+
withscores: Literal[True],
164+
score_cast_func: type | Callable[[float], ScoreT] = float,
165+
) -> list[tuple[bytes, ScoreT]]: ...
166+
167+
def zrem(self, name: KeyT, *values: EncodableT) -> int: ...
168+
169+
def zremrangebyscore(
170+
self,
171+
name: KeyT,
172+
min: ZScoreBoundT,
173+
max: ZScoreBoundT,
174+
) -> int: ...
175+
176+
def zscore(self, name: KeyT, value: EncodableT) -> float | None: ...
177+
178+
179+
class TypedPipeline(TypedRedis, ContextManager["TypedPipeline"]):
180+
pass
181+
182+
183+
def get_redis_client(config: RedisConfig) -> TypedRedis:
19184
url = config.redis_url
20185

21186
# redis.Redis.from_url() doesn't support passing the password separately

tests/test_delta_generator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from datetime import datetime
2+
from typing import cast
23
import hashlib
34
import json
45
import logging
@@ -11,6 +12,7 @@
1112
from flatpak_indexer.cleaner import Cleaner
1213
from flatpak_indexer.delta_generator import DeltaGenerator
1314
from flatpak_indexer.models import RepositoryModel, TardiffResultModel, TardiffSpecModel
15+
from flatpak_indexer.redis_utils import TypedRedis
1416
from flatpak_indexer.test.redis import mock_redis
1517
from flatpak_indexer.utils import path_for_digest
1618
import redis
@@ -375,7 +377,7 @@ def test_delta_generator_expire(tmp_path):
375377
generator.generate()
376378

377379
# Expire the deltas we just generated
378-
redis_client = redis.Redis.from_url(config.redis_url)
380+
redis_client = cast(TypedRedis, redis.Redis.from_url(config.redis_url))
379381
all_tardiffs_raw = redis_client.zrangebyscore("tardiff:active", 0, float("inf"))
380382
all_tardiffs = (k.decode("utf-8") for k in all_tardiffs_raw)
381383
for result_raw in redis_client.mget(*(f"tardiff:result:{k}" for k in all_tardiffs)):

tests/test_differ.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import cast
12
import os
23
import threading
34
import time
@@ -7,6 +8,7 @@
78

89
from flatpak_indexer.differ import Differ
910
from flatpak_indexer.models import TardiffImageModel, TardiffResultModel, TardiffSpecModel
11+
from flatpak_indexer.redis_utils import TypedRedis
1012
from flatpak_indexer.test.redis import mock_redis
1113
import redis
1214

@@ -60,7 +62,7 @@ def queue_task(from_ref, from_diff_id, to_ref, to_diff_id, redis_client=None, sk
6062

6163

6264
def check_success(key, old_layer, new_layer):
63-
redis_client = redis.Redis.from_url("redis://localhost")
65+
redis_client = cast(TypedRedis, redis.Redis.from_url("redis://localhost"))
6466

6567
assert redis_client.scard("tardiff:pending") == 0
6668
assert redis_client.zscore("tardiff:progress", key) is None
@@ -83,7 +85,7 @@ def check_success(key, old_layer, new_layer):
8385

8486

8587
def check_failure(key, status, message):
86-
redis_client = redis.Redis.from_url("redis://localhost")
88+
redis_client = cast(TypedRedis, redis.Redis.from_url("redis://localhost"))
8789

8890
assert redis_client.scard("tardiff:pending") == 0
8991
assert redis_client.zscore("tardiff:progress", key) is None

tests/test_redis_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import time
66

77
import pytest
8+
import redis.client
89

910
from flatpak_indexer.redis_utils import RedisConfig, do_pubsub_work, get_redis_client
1011
from flatpak_indexer.test.redis import mock_redis

0 commit comments

Comments
 (0)