Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ repos:
hooks:
- id: mypy
name: mypy
entry: python3 -m mypy
entry: python3 -m mypy cashews
exclude: (^tests/|^perf/|^examples/|setup.py)
pass_filenames: false
language: system
types: [python]
2 changes: 1 addition & 1 deletion cashews/backends/diskcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ async def exists(self, key: Key) -> bool:
def _exists(self, key: Key) -> bool:
return key in self._cache

async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Value]: # type: ignore
async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Value]:
if not self._sharded:
for key in await self._run_in_executor(self._scan, pattern):
yield key
Expand Down
2 changes: 1 addition & 1 deletion cashews/backends/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ async def set_many(self, pairs: Mapping[Key, Value], expire: float | None = None
value = await self._serializer.encode(self, key=key, value=value, expire=expire)
self._set(key, value, expire)

async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Key]: # type: ignore
async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Key]:
pattern = pattern.replace("*", ".*")
regexp = re.compile(pattern)
for key in dict(self.store):
Expand Down
4 changes: 2 additions & 2 deletions cashews/backends/redis/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ async def delete(self, key: Key) -> bool:
async def exists(self, key: Key) -> bool:
return bool(await self._client.exists(key))

async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Key]: # type: ignore
async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Key]:
cursor = 0
while True:
cursor, keys = await self._client.scan(cursor, match=pattern, count=batch_size)
Expand Down Expand Up @@ -193,7 +193,7 @@ async def delete_match(self, pattern: str):
await self._client.unlink(*keys)
await self._call_on_remove_callbacks(*[key.decode() for key in keys])

async def get_match(self, pattern: str, batch_size: int = 100) -> AsyncIterator[tuple[Key, Value]]: # type: ignore
async def get_match(self, pattern: str, batch_size: int = 100) -> AsyncIterator[tuple[Key, Value]]:
cursor = 0
while True:
cursor, keys = await self._client.scan(cursor, match=pattern, count=batch_size)
Expand Down
4 changes: 2 additions & 2 deletions cashews/backends/redis/client_side.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ async def set_many(self, pairs: Mapping[Key, Value], expire: float | None = None
expire=expire,
)

async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Key]: # type: ignore
async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Key]:
async for key in super().scan(self._add_prefix(pattern), batch_size=batch_size):
yield self._remove_prefix(key)

Expand All @@ -227,7 +227,7 @@ async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Valu
await self._local_cache.set(key, _empty_in_redis)
return tuple(missed.get(key, value) for key, value in values.items())

async def get_match(self, pattern: str, batch_size: int = 100) -> AsyncIterator[tuple[Key, Value]]: # type: ignore
async def get_match(self, pattern: str, batch_size: int = 100) -> AsyncIterator[tuple[Key, Value]]:
cursor = 0
while True:
cursor, keys = await self._client.scan(cursor, match=self._add_prefix(pattern), count=batch_size)
Expand Down
6 changes: 3 additions & 3 deletions cashews/contrib/_starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ async def encode_streaming_response(
async def decode_streaming_response(value: bytes, backend: Backend, key: str, **kwargs) -> StreamingResponse:
if not await backend.get(f"{key}:done"):
raise DecodeError()
status_code, headers = value.split(b":")
status_code, headers = value.split(b":", maxsplit=1)
raw_headers = []
for header in headers.split(b";"):
for header in headers.split(b";", maxsplit=1):
if not header:
continue
header_name, header_value = header.split(b"=")
header_name, header_value = header.split(b"=", maxsplit=1)
raw_headers.append((header_name, header_value))

content = get_iterator(backend, key)
Expand Down
4 changes: 2 additions & 2 deletions cashews/contrib/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _get_max_age(cache_control_value: str) -> int | None:
for cc_value_item in cache_control_value.split(","):
cc_value_item = cc_value_item.strip()
try:
key, value = cc_value_item.split("=")
key, value = cc_value_item.split("=", maxsplit=1)
if key == _MAX_AGE[:-1]:
return int(value)
except (ValueError, TypeError):
Expand Down Expand Up @@ -192,7 +192,7 @@ def _get_etag(cached_data: Any) -> str:


def _is_early_cache(data: Any) -> bool:
return isinstance(data, list) and data and isinstance(data[0], datetime)
return bool(isinstance(data, list) and data and isinstance(data[0], datetime))


class CacheDeleteMiddleware(BaseHTTPMiddleware):
Expand Down
6 changes: 3 additions & 3 deletions cashews/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def format_field(self, value: Any, format_spec: str) -> TemplateValue:
def parse_format_spec(format_spec: str):
if not format_spec or "(" not in format_spec:
return format_spec, ()
format_spec, args = format_spec.split("(", 1)
format_spec, args = format_spec.split("(", maxsplit=1)
return format_spec, args.replace(")", "").split(",")

def vformat(self, format_string, args, kwargs):
Expand All @@ -154,7 +154,7 @@ def vformat(self, format_string, args, kwargs):

@default_formatter.register("get", preformat=False)
def _get(value: Any, key: str) -> TemplateValue:
return value.get(key)
return str(value.get(key))


@default_formatter.register("len")
Expand All @@ -164,7 +164,7 @@ def _len(value: TemplateValue):

@default_formatter.register("jwt")
def _jwt_func(jwt: TemplateValue, key: str) -> TemplateValue:
_, payload, _ = jwt.split(".", 2)
_, payload, _ = jwt.split(".", maxsplit=2)
payload_dict = json.loads(base64.b64decode(payload))
return payload_dict.get(key)

Expand Down
6 changes: 3 additions & 3 deletions cashews/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def sign(self, key: Key, value: bytes) -> bytes:

def check_sign(self, key: Key, value: bytes) -> bytes:
try:
sign, value = value.split(b"_", 1)
sign, value = value.split(b"_", maxsplit=1)
except ValueError as exc:
raise SignIsMissingError(f"key: {key}") from exc

Expand All @@ -65,7 +65,7 @@ def _gen_sign(self, key: Key, value: bytes, digestmod: bytes) -> bytes:
def _get_sign_and_digestmod(self, sign: bytes) -> tuple[bytes, bytes]:
digestmod = self._digestmod
if b":" in sign:
digestmod, sign = sign.split(b":")
digestmod, sign = sign.split(b":", maxsplit=1)
if digestmod not in self._digestmods:
raise UnSecureDataError()
return sign, digestmod
Expand Down Expand Up @@ -145,7 +145,7 @@ def _decode(self, value: bytes) -> Value:

async def _custom_decode(self, backend: Backend, key: Key, value: bytes, default: Value) -> Value:
try:
value_type, value = value.split(b":", 1)
value_type, value = value.split(b":", maxsplit=1)
except ValueError:
return default
if value_type not in self._type_mapping:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_intergations/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,12 @@ def iterator():
@app.get("/stream")
@cache(ttl="10s", key="stream")
async def stream():
return StreamingResponse(iterator(), status_code=201, headers={"X-Test": "TRUE"})
return StreamingResponse(iterator(), status_code=201, headers={"X-Test": "TRUE", "X-Header": "Some;:=Value"})

response = client.get("/stream")
assert response.status_code == 201
assert response.headers["X-Test"] == "TRUE"
assert response.headers["X-Header"] == "Some;:=Value"
assert response.content == b"0123456789"

response = client.get("/stream")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ async def func(user): ...
),
(
(),
{"kwarg": "1"},
{"kwarg": 1},
"{@:get(context_value)}:{kwarg}",
"context:1",
),
Expand Down
Loading