diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e10dbf59..41098087 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,7 +25,8 @@ repos: hooks: - id: mypy name: mypy - entry: python3 -m mypy + entry: python3 -m mypy cashews exclude: (^tests/|^perf/|^examples/|setup.py) + pass_filenames: false language: system types: [python] diff --git a/cashews/backends/diskcache.py b/cashews/backends/diskcache.py index f15b1a93..2bb916cc 100644 --- a/cashews/backends/diskcache.py +++ b/cashews/backends/diskcache.py @@ -111,7 +111,7 @@ async def exists(self, key: Key) -> bool: def _exists(self, key: Key) -> bool: return key in self._cache - async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Value]: # type: ignore + async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Value]: if not self._sharded: for key in await self._run_in_executor(self._scan, pattern): yield key diff --git a/cashews/backends/memory.py b/cashews/backends/memory.py index 80779faf..382de7c1 100644 --- a/cashews/backends/memory.py +++ b/cashews/backends/memory.py @@ -105,7 +105,7 @@ async def set_many(self, pairs: Mapping[Key, Value], expire: float | None = None value = await self._serializer.encode(self, key=key, value=value, expire=expire) self._set(key, value, expire) - async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Key]: # type: ignore + async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Key]: pattern = pattern.replace("*", ".*") regexp = re.compile(pattern) for key in dict(self.store): diff --git a/cashews/backends/redis/backend.py b/cashews/backends/redis/backend.py index 1b554f4d..a777a457 100644 --- a/cashews/backends/redis/backend.py +++ b/cashews/backends/redis/backend.py @@ -164,7 +164,7 @@ async def delete(self, key: Key) -> bool: async def exists(self, key: Key) -> bool: return bool(await self._client.exists(key)) - async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Key]: # type: ignore + async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Key]: cursor = 0 while True: cursor, keys = await self._client.scan(cursor, match=pattern, count=batch_size) @@ -193,7 +193,7 @@ async def delete_match(self, pattern: str): await self._client.unlink(*keys) await self._call_on_remove_callbacks(*[key.decode() for key in keys]) - async def get_match(self, pattern: str, batch_size: int = 100) -> AsyncIterator[tuple[Key, Value]]: # type: ignore + async def get_match(self, pattern: str, batch_size: int = 100) -> AsyncIterator[tuple[Key, Value]]: cursor = 0 while True: cursor, keys = await self._client.scan(cursor, match=pattern, count=batch_size) diff --git a/cashews/backends/redis/client_side.py b/cashews/backends/redis/client_side.py index da6636eb..d452078c 100644 --- a/cashews/backends/redis/client_side.py +++ b/cashews/backends/redis/client_side.py @@ -202,7 +202,7 @@ async def set_many(self, pairs: Mapping[Key, Value], expire: float | None = None expire=expire, ) - async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Key]: # type: ignore + async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Key]: async for key in super().scan(self._add_prefix(pattern), batch_size=batch_size): yield self._remove_prefix(key) @@ -227,7 +227,7 @@ async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Valu await self._local_cache.set(key, _empty_in_redis) return tuple(missed.get(key, value) for key, value in values.items()) - async def get_match(self, pattern: str, batch_size: int = 100) -> AsyncIterator[tuple[Key, Value]]: # type: ignore + async def get_match(self, pattern: str, batch_size: int = 100) -> AsyncIterator[tuple[Key, Value]]: cursor = 0 while True: cursor, keys = await self._client.scan(cursor, match=self._add_prefix(pattern), count=batch_size) diff --git a/cashews/contrib/_starlette.py b/cashews/contrib/_starlette.py index 4a196eb3..547923c7 100644 --- a/cashews/contrib/_starlette.py +++ b/cashews/contrib/_starlette.py @@ -18,12 +18,12 @@ async def encode_streaming_response( async def decode_streaming_response(value: bytes, backend: Backend, key: str, **kwargs) -> StreamingResponse: if not await backend.get(f"{key}:done"): raise DecodeError() - status_code, headers = value.split(b":") + status_code, headers = value.split(b":", maxsplit=1) raw_headers = [] - for header in headers.split(b";"): + for header in headers.split(b";", maxsplit=1): if not header: continue - header_name, header_value = header.split(b"=") + header_name, header_value = header.split(b"=", maxsplit=1) raw_headers.append((header_name, header_value)) content = get_iterator(backend, key) diff --git a/cashews/contrib/fastapi.py b/cashews/contrib/fastapi.py index 4f7b69f6..c1b658ec 100644 --- a/cashews/contrib/fastapi.py +++ b/cashews/contrib/fastapi.py @@ -119,7 +119,7 @@ def _get_max_age(cache_control_value: str) -> int | None: for cc_value_item in cache_control_value.split(","): cc_value_item = cc_value_item.strip() try: - key, value = cc_value_item.split("=") + key, value = cc_value_item.split("=", maxsplit=1) if key == _MAX_AGE[:-1]: return int(value) except (ValueError, TypeError): @@ -192,7 +192,7 @@ def _get_etag(cached_data: Any) -> str: def _is_early_cache(data: Any) -> bool: - return isinstance(data, list) and data and isinstance(data[0], datetime) + return bool(isinstance(data, list) and data and isinstance(data[0], datetime)) class CacheDeleteMiddleware(BaseHTTPMiddleware): diff --git a/cashews/formatter.py b/cashews/formatter.py index 7a4c207b..90913692 100644 --- a/cashews/formatter.py +++ b/cashews/formatter.py @@ -139,7 +139,7 @@ def format_field(self, value: Any, format_spec: str) -> TemplateValue: def parse_format_spec(format_spec: str): if not format_spec or "(" not in format_spec: return format_spec, () - format_spec, args = format_spec.split("(", 1) + format_spec, args = format_spec.split("(", maxsplit=1) return format_spec, args.replace(")", "").split(",") def vformat(self, format_string, args, kwargs): @@ -154,7 +154,7 @@ def vformat(self, format_string, args, kwargs): @default_formatter.register("get", preformat=False) def _get(value: Any, key: str) -> TemplateValue: - return value.get(key) + return str(value.get(key)) @default_formatter.register("len") @@ -164,7 +164,7 @@ def _len(value: TemplateValue): @default_formatter.register("jwt") def _jwt_func(jwt: TemplateValue, key: str) -> TemplateValue: - _, payload, _ = jwt.split(".", 2) + _, payload, _ = jwt.split(".", maxsplit=2) payload_dict = json.loads(base64.b64decode(payload)) return payload_dict.get(key) diff --git a/cashews/serialize.py b/cashews/serialize.py index f04ff5fe..35284e15 100644 --- a/cashews/serialize.py +++ b/cashews/serialize.py @@ -48,7 +48,7 @@ def sign(self, key: Key, value: bytes) -> bytes: def check_sign(self, key: Key, value: bytes) -> bytes: try: - sign, value = value.split(b"_", 1) + sign, value = value.split(b"_", maxsplit=1) except ValueError as exc: raise SignIsMissingError(f"key: {key}") from exc @@ -65,7 +65,7 @@ def _gen_sign(self, key: Key, value: bytes, digestmod: bytes) -> bytes: def _get_sign_and_digestmod(self, sign: bytes) -> tuple[bytes, bytes]: digestmod = self._digestmod if b":" in sign: - digestmod, sign = sign.split(b":") + digestmod, sign = sign.split(b":", maxsplit=1) if digestmod not in self._digestmods: raise UnSecureDataError() return sign, digestmod @@ -145,7 +145,7 @@ def _decode(self, value: bytes) -> Value: async def _custom_decode(self, backend: Backend, key: Key, value: bytes, default: Value) -> Value: try: - value_type, value = value.split(b":", 1) + value_type, value = value.split(b":", maxsplit=1) except ValueError: return default if value_type not in self._type_mapping: diff --git a/tests/test_intergations/test_fastapi.py b/tests/test_intergations/test_fastapi.py index 8116747f..1f01dc62 100644 --- a/tests/test_intergations/test_fastapi.py +++ b/tests/test_intergations/test_fastapi.py @@ -106,11 +106,12 @@ def iterator(): @app.get("/stream") @cache(ttl="10s", key="stream") async def stream(): - return StreamingResponse(iterator(), status_code=201, headers={"X-Test": "TRUE"}) + return StreamingResponse(iterator(), status_code=201, headers={"X-Test": "TRUE", "X-Header": "Some;:=Value"}) response = client.get("/stream") assert response.status_code == 201 assert response.headers["X-Test"] == "TRUE" + assert response.headers["X-Header"] == "Some;:=Value" assert response.content == b"0123456789" response = client.get("/stream") diff --git a/tests/test_key.py b/tests/test_key.py index f6aaf124..11055b03 100644 --- a/tests/test_key.py +++ b/tests/test_key.py @@ -167,7 +167,7 @@ async def func(user): ... ), ( (), - {"kwarg": "1"}, + {"kwarg": 1}, "{@:get(context_value)}:{kwarg}", "context:1", ),