Skip to content

Commit 65b0bd2

Browse files
authored
feat: zstd middleware (#4286)
1 parent 0bcd53c commit 65b0bd2

File tree

15 files changed

+3037
-2833
lines changed

15 files changed

+3037
-2833
lines changed

litestar/config/compression.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class CompressionConfig:
2121
using the ``compression_config`` key.
2222
"""
2323

24-
backend: Literal["gzip", "brotli"] | str
24+
backend: Literal["gzip", "brotli", "zstd"] | str
2525
"""The backend to use.
2626
2727
If the value given is `gzip` or `brotli`, then the builtin gzip and brotli compression is used.
@@ -30,6 +30,8 @@ class CompressionConfig:
3030
"""Minimum response size (bytes) to enable compression, affects all backends."""
3131
gzip_compress_level: int = field(default=9)
3232
"""Range ``[0-9]``, see :doc:`python:library/gzip`."""
33+
zstd_compress_level: int = field(default=3)
34+
"""Range `[0-22]`, see `zstandard <https://pypi.org/project/zstandard/>`"""
3335
brotli_quality: int = field(default=5)
3436
"""Range ``[0-11]``, Controls the compression-speed vs compression-density tradeoff.
3537
@@ -59,6 +61,8 @@ class CompressionConfig:
5961
"""The compression facade to use for the actual compression."""
6062
backend_config: Any = None
6163
"""Configuration specific to the backend."""
64+
zstd_gzip_fallback: bool = True
65+
"""Use GZIP as a fallback if Zstd is not supported by the client."""
6266
gzip_fallback: bool = True
6367
"""Use GZIP as a fallback if the provided backend is not supported by the client."""
6468

@@ -81,3 +85,13 @@ def __post_init__(self) -> None:
8185

8286
self.gzip_fallback = self.brotli_gzip_fallback
8387
self.compression_facade = BrotliCompression
88+
elif self.backend == "zstd":
89+
if self.zstd_compress_level < 1 or self.zstd_compress_level > 22:
90+
raise ImproperlyConfiguredException(
91+
f"zstd_compress_level must be between 1 and 22, given: {self.zstd_compress_level}"
92+
)
93+
94+
from litestar.middleware.compression.zstd_facade import ZstdCompression
95+
96+
self.gzip_fallback = self.zstd_gzip_fallback
97+
self.compression_facade = ZstdCompression

litestar/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ class CompressionEncoding(str, Enum):
7777

7878
GZIP = "gzip"
7979
BROTLI = "br"
80+
ZSTD = "zstd"
8081

8182

8283
class ASGIExtension(str, Enum):

litestar/middleware/compression/brotli_facade.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(
4343
lgblock=config.brotli_lgblock,
4444
)
4545

46-
def write(self, body: bytes) -> None:
46+
def write(self, body: bytes | bytearray, final: bool = False) -> None:
4747
self.buffer.write(self.compressor.process(body))
4848
self.buffer.write(self.compressor.flush())
4949

litestar/middleware/compression/facade.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,13 @@ def __init__(
2929
"""
3030
...
3131

32-
def write(self, body: bytes) -> None:
33-
"""Write compressed bytes.
32+
def write(self, body: bytes | bytearray, final: bool = False) -> None:
33+
"""Write compressed bytes to the buffer.
3434
3535
Args:
36-
body: Message body to process
36+
body: The message body to process. Can be `bytes` or `bytearray`.
37+
final: Indicates whether this is the last chunk of data. If True,
38+
the compressor may flush any remaining internal buffers.
3739
3840
Returns:
3941
None

litestar/middleware/compression/gzip_facade.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ def __init__(
2424
self.compression_encoding = compression_encoding
2525
self.compressor = GzipFile(mode="wb", fileobj=buffer, compresslevel=config.gzip_compress_level)
2626

27-
def write(self, body: bytes) -> None:
28-
self.compressor.write(body)
27+
def write(self, body: bytes | bytearray, final: bool = False) -> None:
28+
data = bytes(body)
29+
self.compressor.write(data)
2930
self.compressor.flush()
3031

3132
def close(self) -> None:

litestar/middleware/compression/middleware.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from litestar.enums import CompressionEncoding, ScopeType
88
from litestar.middleware.base import AbstractMiddleware
99
from litestar.middleware.compression.gzip_facade import GzipCompression
10+
from litestar.middleware.compression.zstd_facade import ZstdCompression
1011
from litestar.utils.empty import value_or_default
1112
from litestar.utils.scope.state import ScopeState
1213

@@ -82,10 +83,22 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
8283

8384
await self.app(scope, receive, send)
8485

86+
def get_facade_cls(
87+
self,
88+
compression_encoding: Literal[CompressionEncoding.BROTLI, CompressionEncoding.GZIP, CompressionEncoding.ZSTD]
89+
| str,
90+
) -> type[CompressionFacade]:
91+
if compression_encoding == CompressionEncoding.GZIP:
92+
return GzipCompression
93+
if compression_encoding == CompressionEncoding.ZSTD:
94+
return ZstdCompression
95+
return self.config.compression_facade
96+
8597
def create_compression_send_wrapper(
8698
self,
8799
send: Send,
88-
compression_encoding: Literal[CompressionEncoding.BROTLI, CompressionEncoding.GZIP] | str,
100+
compression_encoding: Literal[CompressionEncoding.BROTLI, CompressionEncoding.GZIP, CompressionEncoding.ZSTD]
101+
| str,
89102
scope: Scope,
90103
) -> Send:
91104
"""Wrap ``send`` to handle brotli compression.
@@ -100,15 +113,10 @@ def create_compression_send_wrapper(
100113
"""
101114
bytes_buffer = BytesIO()
102115

103-
facade: CompressionFacade
104116
# We can't use `self.config.compression_facade` directly if the compression is `gzip` since
105117
# it may be being used as a fallback.
106-
if compression_encoding == CompressionEncoding.GZIP:
107-
facade = GzipCompression(buffer=bytes_buffer, compression_encoding=compression_encoding, config=self.config)
108-
else:
109-
facade = self.config.compression_facade(
110-
buffer=bytes_buffer, compression_encoding=compression_encoding, config=self.config
111-
)
118+
facade_cls: type[CompressionFacade] = self.get_facade_cls(compression_encoding)
119+
facade = facade_cls(buffer=bytes_buffer, compression_encoding=compression_encoding, config=self.config)
112120

113121
initial_message: HTTPResponseStartEvent | None = None
114122
started = False
@@ -151,7 +159,7 @@ async def send_wrapper(message: Message) -> None:
151159
del headers["Content-Length"]
152160
connection_state.response_compressed = True
153161

154-
facade.write(body)
162+
facade.write(body, final=not more_body)
155163

156164
message["body"] = bytes_buffer.getvalue()
157165
bytes_buffer.seek(0)
@@ -160,7 +168,7 @@ async def send_wrapper(message: Message) -> None:
160168
await send(message)
161169

162170
elif len(body) >= self.config.minimum_size:
163-
facade.write(body)
171+
facade.write(body, final=not more_body)
164172
facade.close()
165173
body = bytes_buffer.getvalue()
166174

@@ -180,7 +188,7 @@ async def send_wrapper(message: Message) -> None:
180188
await send(message)
181189

182190
else:
183-
facade.write(body)
191+
facade.write(body, final=not more_body)
184192
if not more_body:
185193
facade.close()
186194

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Literal
4+
5+
from litestar.enums import CompressionEncoding
6+
from litestar.exceptions import MissingDependencyException
7+
from litestar.middleware.compression.facade import CompressionFacade
8+
9+
try:
10+
import zstandard as zstd
11+
except ImportError as e:
12+
raise MissingDependencyException("zstandard") from e
13+
14+
if TYPE_CHECKING:
15+
from io import BytesIO
16+
17+
from litestar.config.compression import CompressionConfig
18+
19+
20+
class ZstdCompression(CompressionFacade):
21+
__slots__ = ("buffer", "cctx", "compression_encoding", "compressor")
22+
23+
encoding = CompressionEncoding("zstd")
24+
25+
def __init__(self, buffer: BytesIO, compression_encoding: Literal["zstd"] | str, config: CompressionConfig) -> None:
26+
self.buffer = buffer
27+
self.compression_encoding = compression_encoding
28+
self.cctx = zstd.ZstdCompressor(level=config.zstd_compress_level)
29+
self.compressor = self.cctx.stream_writer(buffer)
30+
31+
def write(self, body: bytes | bytearray, final: bool = False) -> None:
32+
self.compressor.write(body)
33+
if final:
34+
self.compressor.flush(zstd.FLUSH_FRAME)
35+
else:
36+
self.compressor.flush(zstd.FLUSH_BLOCK)
37+
38+
def close(self) -> None:
39+
self.compressor.flush(zstd.FLUSH_FRAME)

pyproject.toml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ dependencies = [
5656
annotated-types = ["annotated-types"]
5757
attrs = ["attrs"]
5858
brotli = ["brotli"]
59+
zstd = [
60+
# Only needed for <3.14, since >=3.14 ships with stdlib support
61+
"zstandard>=0.24.0; python_version < '3.14'",
62+
]
5963
cli = ["jsbeautifier", "uvicorn[standard]", "uvloop>=0.18.0; sys_platform != 'win32'"]
6064
cryptography = ["cryptography"]
6165
full = [
@@ -80,7 +84,8 @@ full = [
8084
structlog, \
8185
valkey, \
8286
htmx, \
83-
yaml]; \
87+
yaml, \
88+
zstd]; \
8489
python_version < \"3.13\"""",
8590

8691
# full extra without picologging
@@ -104,7 +109,8 @@ full = [
104109
structlog, \
105110
valkey, \
106111
htmx, \
107-
yaml]; \
112+
yaml, \
113+
zstd]; \
108114
python_version >= \"3.13\"""",
109115
]
110116
htmx = ["litestar-htmx>=0.4.0"]
@@ -329,6 +335,8 @@ module = [
329335
"exceptiongroup",
330336
"picologging",
331337
"picologging.*",
338+
"zstandard",
339+
"zstandard.*"
332340
]
333341

334342
[[tool.mypy.overrides]]

tests/e2e/test_logging/test_structlog_to_file.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def handler() -> str:
6161
"headers": {
6262
"host": "testserver.local",
6363
"accept": "*/*",
64-
"accept-encoding": "gzip, deflate, br",
64+
"accept-encoding": "gzip, deflate, br, zstd",
6565
"connection": "keep-alive",
6666
"user-agent": "testclient",
6767
},

tests/unit/test_connection/test_request.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
171171
"headers": {
172172
"host": "example.org",
173173
"user-agent": "testclient",
174-
"accept-encoding": "gzip, deflate, br",
174+
"accept-encoding": "gzip, deflate, br, zstd",
175175
"accept": "*/*",
176176
"connection": "keep-alive",
177177
}

0 commit comments

Comments
 (0)