7
7
from litestar .enums import CompressionEncoding , ScopeType
8
8
from litestar .middleware .base import AbstractMiddleware
9
9
from litestar .middleware .compression .gzip_facade import GzipCompression
10
+ from litestar .middleware .compression .zstd_facade import ZstdCompression
10
11
from litestar .utils .empty import value_or_default
11
12
from litestar .utils .scope .state import ScopeState
12
13
@@ -82,10 +83,22 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
82
83
83
84
await self .app (scope , receive , send )
84
85
86
+ def get_facade_cls (
87
+ self ,
88
+ compression_encoding : Literal [CompressionEncoding .BROTLI , CompressionEncoding .GZIP , CompressionEncoding .ZSTD ]
89
+ | str ,
90
+ ) -> type [CompressionFacade ]:
91
+ if compression_encoding == CompressionEncoding .GZIP :
92
+ return GzipCompression
93
+ if compression_encoding == CompressionEncoding .ZSTD :
94
+ return ZstdCompression
95
+ return self .config .compression_facade
96
+
85
97
def create_compression_send_wrapper (
86
98
self ,
87
99
send : Send ,
88
- compression_encoding : Literal [CompressionEncoding .BROTLI , CompressionEncoding .GZIP ] | str ,
100
+ compression_encoding : Literal [CompressionEncoding .BROTLI , CompressionEncoding .GZIP , CompressionEncoding .ZSTD ]
101
+ | str ,
89
102
scope : Scope ,
90
103
) -> Send :
91
104
"""Wrap ``send`` to handle brotli compression.
@@ -100,15 +113,10 @@ def create_compression_send_wrapper(
100
113
"""
101
114
bytes_buffer = BytesIO ()
102
115
103
- facade : CompressionFacade
104
116
# We can't use `self.config.compression_facade` directly if the compression is `gzip` since
105
117
# it may be being used as a fallback.
106
- if compression_encoding == CompressionEncoding .GZIP :
107
- facade = GzipCompression (buffer = bytes_buffer , compression_encoding = compression_encoding , config = self .config )
108
- else :
109
- facade = self .config .compression_facade (
110
- buffer = bytes_buffer , compression_encoding = compression_encoding , config = self .config
111
- )
118
+ facade_cls : type [CompressionFacade ] = self .get_facade_cls (compression_encoding )
119
+ facade = facade_cls (buffer = bytes_buffer , compression_encoding = compression_encoding , config = self .config )
112
120
113
121
initial_message : HTTPResponseStartEvent | None = None
114
122
started = False
@@ -151,7 +159,7 @@ async def send_wrapper(message: Message) -> None:
151
159
del headers ["Content-Length" ]
152
160
connection_state .response_compressed = True
153
161
154
- facade .write (body )
162
+ facade .write (body , final = not more_body )
155
163
156
164
message ["body" ] = bytes_buffer .getvalue ()
157
165
bytes_buffer .seek (0 )
@@ -160,7 +168,7 @@ async def send_wrapper(message: Message) -> None:
160
168
await send (message )
161
169
162
170
elif len (body ) >= self .config .minimum_size :
163
- facade .write (body )
171
+ facade .write (body , final = not more_body )
164
172
facade .close ()
165
173
body = bytes_buffer .getvalue ()
166
174
@@ -180,7 +188,7 @@ async def send_wrapper(message: Message) -> None:
180
188
await send (message )
181
189
182
190
else :
183
- facade .write (body )
191
+ facade .write (body , final = not more_body )
184
192
if not more_body :
185
193
facade .close ()
186
194
0 commit comments