Skip to content

Commit 1d02b39

Browse files
committed
Work around deficiencies in redis-py typing
redis-py has problematical type annotations - async and sync return types are mixed together, and there is no representation of the fact that the client *might* decode the responses from bytes to str, or not, depending on how it was constructed. See redis/redis-py#3619 Work around this by adding protocols that represent the types of the API as we use it, and cast to the protocols as needed.
1 parent b92e8d9 commit 1d02b39

File tree

9 files changed

+179
-28
lines changed

9 files changed

+179
-28
lines changed

flatpak_indexer/bodhi_query.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import logging
55
import re
66

7+
from flatpak_indexer.redis_utils import TypedPipeline
8+
79
from .http_utils import HttpConfig
810
from .models import BodhiUpdateModel
911
from .nvr import NVR
@@ -198,8 +200,10 @@ def _query_updates(
198200
_run_query(requests_session, content_type, url, params, save_entities, results)
199201

200202

201-
def _refresh_updates(session: Session, content_type, entities, pipe, rows_per_page=None):
202-
pipe.watch("updates-by-entity:" + content_type)
203+
def _refresh_updates(
204+
session: Session, content_type, entities, pipe: TypedPipeline, *, rows_per_page: int
205+
):
206+
pipe.watch("update-cache:" + content_type)
203207

204208
assert isinstance(session.config, HttpConfig)
205209
requests_session = session.config.get_requests_session()
@@ -263,7 +267,9 @@ def refresh_updates(session: Session, content_type, entities, rows_per_page=10):
263267
)
264268

265269

266-
def _refresh_all_updates(session: Session, content_type, pipe, rows_per_page=10):
270+
def _refresh_all_updates(
271+
session: Session, content_type, pipe: TypedPipeline, *, rows_per_page: int
272+
):
267273
pipe.watch("updates-by-entity:" + content_type)
268274

269275
assert isinstance(session.config, HttpConfig)

flatpak_indexer/datasource/fedora/updater.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from collections import defaultdict
22
from typing import DefaultDict, Dict, List, NamedTuple, Optional, Set
33

4-
import redis
5-
64
from ...bodhi_query import (
75
list_updates,
86
refresh_all_updates,
@@ -62,7 +60,6 @@ def _fix_pull_spec(image: ImageModel, registry_url: str, repo_name: str):
6260

6361

6462
class FedoraUpdater(Updater):
65-
redis_client: "redis.Redis[bytes]"
6663
change_monitor: Optional[FedoraMonitor]
6764

6865
def __init__(self, config: Config):

flatpak_indexer/delta_generator.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import defaultdict
22
from datetime import datetime, timedelta, timezone
3-
from typing import DefaultDict, Dict, List, Optional, Set, Tuple, cast
3+
from typing import DefaultDict, Dict, List, Optional, Set, Tuple
44
import json
55
import logging
66
import time
@@ -209,8 +209,7 @@ def do_work(pubsub):
209209
next_expire = now + self.progress_timeout_seconds
210210
with self.redis_client.pipeline() as pipe:
211211
pipe.watch("tardiff:progress")
212-
pre = cast(redis.Redis, pipe)
213-
stale = pre.zrangebyscore(
212+
stale = pipe.zrangebyscore(
214213
"tardiff:progress", 0, now - self.progress_timeout_seconds
215214
)
216215
if len(stale) > 0:
@@ -227,7 +226,7 @@ def do_work(pubsub):
227226
# progress was modified, immediately try again
228227
return True
229228
else:
230-
oldest: List[Tuple[bytes, float]] = pre.zrange(
229+
oldest: List[Tuple[bytes, float]] = pipe.zrange(
231230
"tardiff:progress", 0, 0, withscores=True
232231
)
233232
if len(oldest) > 0:

flatpak_indexer/differ.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from datetime import datetime
2-
from typing import cast
32
import logging
43
import os
54
import subprocess
@@ -42,12 +41,11 @@ def _wait_for_task(self, pubsub):
4241
def _get_task(self):
4342
with self.redis_client.pipeline() as pipe:
4443
pipe.watch("tardiff:pending")
45-
pre = cast("redis.Redis[bytes]", pipe)
46-
task_raw = pre.srandmember("tardiff:pending")
44+
task_raw = pipe.srandmember("tardiff:pending")
4745
if task_raw is None:
4846
return None
4947

50-
task = cast(bytes, task_raw).decode("utf-8")
48+
task = task_raw.decode("utf-8")
5149

5250
pipe.multi()
5351
pipe.srem("tardiff:pending", task)

flatpak_indexer/fedora_monitor.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import partial
2-
from typing import Optional, Set, Tuple, cast
2+
from typing import Optional, Set, Tuple
33
import json
44
import logging
55
import os
@@ -12,9 +12,7 @@
1212
import pika.credentials
1313
import pika.exceptions
1414

15-
import redis
16-
17-
from .redis_utils import RedisConfig, get_redis_client
15+
from .redis_utils import RedisConfig, TypedPipeline, get_redis_client
1816

1917
logger = logging.getLogger(__name__)
2018

@@ -180,17 +178,14 @@ def _maybe_reraise_failure(self, msg):
180178
if self.failure:
181179
raise RuntimeError(msg) from self.failure
182180

183-
def _do_add_to_log(
184-
self, new_queue_name, update_id, distgit_path, pipe: "redis.client.Pipeline[bytes]"
185-
):
181+
def _do_add_to_log(self, new_queue_name, update_id, distgit_path, pipe: TypedPipeline):
186182
pipe.watch(KEY_SERIAL)
187183
if self.watch_bodhi_updates and (new_queue_name or update_id):
188184
pipe.watch(KEY_UPDATE_CHANGELOG)
189185
if self.watch_distgit_changes and (new_queue_name or distgit_path):
190186
pipe.watch(KEY_DISTGIT_CHANGELOG)
191187

192-
pre = cast("redis.Redis[bytes]", pipe)
193-
serial = 1 + int(pre.get("updatequeue:serial") or 0)
188+
serial = 1 + int(pipe.get("updatequeue:serial") or 0)
194189

195190
pipe.multi()
196191
pipe.set(KEY_SERIAL, serial)

flatpak_indexer/redis_utils.py

Lines changed: 153 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
from typing import Optional
1+
from typing import Callable, ContextManager, Iterable, List, Literal, Optional, Protocol, overload
22
from urllib.parse import quote, urlparse, urlunparse
33
import logging
44
import time
55

6+
from redis.typing import EncodableT, ExpiryT, FieldT, KeysT, KeyT, ZScoreBoundT
7+
import redis.client
8+
69
import redis
710

811
from .base_config import BaseConfig
@@ -15,7 +18,155 @@ class RedisConfig(BaseConfig):
1518
redis_password: Optional[str] = None
1619

1720

18-
def get_redis_client(config: RedisConfig) -> "redis.Redis[bytes]":
21+
class TypedRedis(Protocol):
22+
"""
23+
A typed subset of redis.Redis methods used by flatpak-indexer.
24+
25+
The redis-py type annotations are a mess: async and sync return types
26+
are mixed together, and there is no representation of the fact that
27+
the client *might* decode the responses from bytes to str, or not,
28+
depending on how it was constructed.
29+
30+
This Protocol represents the methods we actually use, with more precise
31+
types - we simply cast the redis.Redis instances we get to this type.
32+
"""
33+
34+
def delete(self, *names: KeyT) -> int: ...
35+
36+
def execute(self) -> List[object]: ...
37+
38+
def exists(self, name: KeyT) -> bool: ...
39+
40+
def hmget(self, name: KeyT, keys: List[KeyT], *args) -> list[bytes | None]: ...
41+
42+
def hget(self, name: KeyT, key: KeyT) -> bytes | None: ...
43+
44+
def get(self, name: KeyT) -> Optional[bytes]: ...
45+
46+
@overload
47+
def hset(self, name: KeyT, key: KeyT, value: EncodableT) -> int: ...
48+
49+
@overload
50+
def hset(self, name: KeyT, mapping: dict[KeyT, EncodableT]) -> int: ...
51+
52+
def mget(self, keys: KeysT, *args) -> list[bytes | None]: ...
53+
54+
def multi(self): ...
55+
56+
def pipeline(self, *args, **kwargs) -> "TypedPipeline": ...
57+
58+
def publish(self, channel: KeyT, message: EncodableT) -> int: ...
59+
60+
def sadd(self, name: KeyT, *values: FieldT) -> int: ...
61+
62+
def scan_iter(self, match: Optional[EncodableT] = None) -> Iterable[str]: ...
63+
64+
def scard(self, name: KeyT) -> int: ...
65+
66+
def set(self, name: KeyT, value: EncodableT) -> bool: ...
67+
68+
def setex(self, name: KeyT, time: ExpiryT, value: EncodableT) -> bool: ...
69+
70+
@overload
71+
def srandmember(self, name: KeyT, number: None = None) -> Optional[bytes]: ...
72+
73+
@overload
74+
def srandmember(self, name: KeyT, number: int) -> List[bytes]: ...
75+
76+
def srem(self, name: KeyT, *values: FieldT) -> int: ...
77+
78+
def transaction(self, func: Callable[["TypedPipeline"], None]): ...
79+
80+
def watch(self, keys: KeysT): ...
81+
82+
def zadd(self, name: KeyT, mapping: dict[EncodableT, float], xx: bool = False) -> int: ...
83+
84+
def zcard(self, name: KeyT) -> int: ...
85+
86+
@overload
87+
def zrange(
88+
self,
89+
name: KeyT,
90+
start: int,
91+
end: int,
92+
*,
93+
desc: bool = False,
94+
withscores: Literal[False] = False,
95+
score_cast_func: type | Callable = float,
96+
byscore: bool = False,
97+
bylex: bool = False,
98+
offset: Optional[int] = None,
99+
num: Optional[int] = None,
100+
) -> list[bytes]: ...
101+
102+
@overload
103+
def zrange(
104+
self,
105+
name: KeyT,
106+
start: int,
107+
end: int,
108+
*,
109+
desc: bool = False,
110+
withscores: Literal[True],
111+
score_cast_func: type | Callable = float,
112+
byscore: bool = False,
113+
bylex: bool = False,
114+
offset: Optional[int] = None,
115+
num: Optional[int] = None,
116+
) -> list[tuple[bytes, float]]: ...
117+
118+
def zrangebylex(
119+
self,
120+
name: KeyT,
121+
min: EncodableT,
122+
max: EncodableT,
123+
start: Optional[int] = None,
124+
num: Optional[int] = None,
125+
) -> list[bytes]: ...
126+
127+
@overload
128+
def zrangebyscore(
129+
self,
130+
name: KeyT,
131+
min: ZScoreBoundT,
132+
max: ZScoreBoundT,
133+
*,
134+
start: Optional[int] = None,
135+
num: Optional[int] = None,
136+
withscores: Literal[False] = False,
137+
score_cast_func: type | Callable = float,
138+
) -> list[bytes]: ...
139+
140+
@overload
141+
def zrangebyscore(
142+
self,
143+
name: KeyT,
144+
min: ZScoreBoundT,
145+
max: ZScoreBoundT,
146+
*,
147+
start: Optional[int] = None,
148+
num: Optional[int] = None,
149+
withscores: Literal[True],
150+
score_cast_func: type | Callable = float,
151+
) -> list[tuple[bytes, float]]: ...
152+
153+
def zrem(self, name: KeyT, *values: EncodableT) -> int: ...
154+
155+
def zremrangebyscore(
156+
self,
157+
name: KeyT,
158+
min: ZScoreBoundT,
159+
max: ZScoreBoundT,
160+
) -> int: ...
161+
162+
def zscore(self, name: KeyT, value: EncodableT) -> float | None: ...
163+
164+
165+
class TypedPipeline(TypedRedis, ContextManager["TypedPipeline"]):
166+
pass
167+
168+
169+
def get_redis_client(config: RedisConfig) -> TypedRedis:
19170
url = config.redis_url
20171

21172
# redis.Redis.from_url() doesn't support passing the password separately

tests/test_delta_generator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from datetime import datetime
2+
from typing import cast
23
import hashlib
34
import json
45
import logging
@@ -11,6 +12,7 @@
1112
from flatpak_indexer.cleaner import Cleaner
1213
from flatpak_indexer.delta_generator import DeltaGenerator
1314
from flatpak_indexer.models import RepositoryModel, TardiffResultModel, TardiffSpecModel
15+
from flatpak_indexer.redis_utils import TypedRedis
1416
from flatpak_indexer.test.redis import mock_redis
1517
from flatpak_indexer.utils import path_for_digest
1618
import redis
@@ -375,7 +377,7 @@ def test_delta_generator_expire(tmp_path):
375377
generator.generate()
376378

377379
# Expire the deltas we just generated
378-
redis_client = redis.Redis.from_url(config.redis_url)
380+
redis_client = cast(TypedRedis, redis.Redis.from_url(config.redis_url))
379381
all_tardiffs_raw = redis_client.zrangebyscore("tardiff:active", 0, float("inf"))
380382
all_tardiffs = (k.decode("utf-8") for k in all_tardiffs_raw)
381383
for result_raw in redis_client.mget(*(f"tardiff:result:{k}" for k in all_tardiffs)):

tests/test_differ.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import cast
12
import os
23
import threading
34
import time
@@ -7,6 +8,7 @@
78

89
from flatpak_indexer.differ import Differ
910
from flatpak_indexer.models import TardiffImageModel, TardiffResultModel, TardiffSpecModel
11+
from flatpak_indexer.redis_utils import TypedRedis
1012
from flatpak_indexer.test.redis import mock_redis
1113
import redis
1214

@@ -60,7 +62,7 @@ def queue_task(from_ref, from_diff_id, to_ref, to_diff_id, redis_client=None, sk
6062

6163

6264
def check_success(key, old_layer, new_layer):
63-
redis_client = redis.Redis.from_url("redis://localhost")
65+
redis_client = cast(TypedRedis, redis.Redis.from_url("redis://localhost"))
6466

6567
assert redis_client.scard("tardiff:pending") == 0
6668
assert redis_client.zscore("tardiff:progress", key) is None
@@ -83,7 +85,7 @@ def check_success(key, old_layer, new_layer):
8385

8486

8587
def check_failure(key, status, message):
86-
redis_client = redis.Redis.from_url("redis://localhost")
88+
redis_client = cast(TypedRedis, redis.Redis.from_url("redis://localhost"))
8789

8890
assert redis_client.scard("tardiff:pending") == 0
8991
assert redis_client.zscore("tardiff:progress", key) is None

tests/test_redis_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import time
66

77
import pytest
8+
import redis.client
89

910
from flatpak_indexer.redis_utils import RedisConfig, do_pubsub_work, get_redis_client
1011
from flatpak_indexer.test.redis import mock_redis

0 commit comments

Comments
 (0)