Skip to content

Commit

Permalink
Fix thread/task safety for WS (#158)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ousret authored Oct 13, 2024
1 parent bd010e9 commit bd0d634
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 124 deletions.
6 changes: 6 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
2.10.904 (2024-10-13)
=====================

- Fixed thread/task safety with WebSocket R/W operations.
- Fixed missing propagation of callbacks (e.g. ``on_post_connection``) in retries of failed requests.

2.10.903 (2024-10-12)
=====================

Expand Down
3 changes: 3 additions & 0 deletions src/urllib3/_async/connectionpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -1658,6 +1658,9 @@ async def urlopen(
body_pos=body_pos,
preload_content=preload_content,
decode_content=decode_content,
on_early_response=on_early_response,
on_upload_body=on_upload_body,
on_post_connection=on_post_connection,
multiplexed=multiplexed,
**response_kw,
)
Expand Down
2 changes: 1 addition & 1 deletion src/urllib3/_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# This file is protected via CODEOWNERS
from __future__ import annotations

__version__ = "2.10.903"
__version__ = "2.10.904"
3 changes: 3 additions & 0 deletions src/urllib3/connectionpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -1613,6 +1613,9 @@ def urlopen(
body_pos=body_pos,
preload_content=preload_content,
decode_content=decode_content,
on_early_response=on_early_response,
on_upload_body=on_upload_body,
on_post_connection=on_post_connection,
multiplexed=multiplexed,
**response_kw,
)
Expand Down
6 changes: 3 additions & 3 deletions src/urllib3/contrib/webextensions/_async/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ async def send_payload(self, buf: str | bytes) -> None:
if self._police_officer is None or self._dsa is None:
raise OSError("The HTTP extension is closed or uninitialized")

async with self._police_officer.borrow(self._response):
if isinstance(buf, str):
buf = buf.encode()
if isinstance(buf, str):
buf = buf.encode()

async with self._police_officer.borrow(self._response):
async with self._write_error_catcher():
await self._dsa.sendall(buf)
119 changes: 63 additions & 56 deletions src/urllib3/contrib/webextensions/_async/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,22 @@ def headers(self, http_version: HttpVersion) -> dict[str, str]:
async def close(self) -> None:
"""End/Notify close for sub protocol."""
if self._dsa is not None:
if self._remote_shutdown is False:
try:
data_to_send: bytes = self._protocol.send(CloseConnection(0))
except WebSocketProtocolError as e:
await self.close()
raise ProtocolError from e
async with self._write_error_catcher():
await self._dsa.sendall(data_to_send)
await self._dsa.close()
self._dsa = None
if self._police_officer is not None:
async with self._police_officer.borrow(self._response):
if self._remote_shutdown is False:
try:
data_to_send: bytes = self._protocol.send(
CloseConnection(0)
)
except WebSocketProtocolError as e:
await self.close()
raise ProtocolError from e
async with self._write_error_catcher():
await self._dsa.sendall(data_to_send)
await self._dsa.close()
self._dsa = None
else:
self._dsa = None
if self._response is not None:
if self._police_officer is not None:
self._police_officer.forget(self._response)
Expand All @@ -126,35 +132,7 @@ async def next_payload(self) -> str | bytes | None:
if self._dsa is None or self._response is None or self._police_officer is None:
raise OSError("The HTTP extension is closed or uninitialized")

for event in self._protocol.events():
if isinstance(event, TextMessage):
return event.data
elif isinstance(event, BytesMessage):
return event.data
elif isinstance(event, CloseConnection):
self._remote_shutdown = True
await self.close()
return None
elif isinstance(event, Ping):
try:
data_to_send: bytes = self._protocol.send(event.response())
except WebSocketProtocolError as e:
await self.close()
raise ProtocolError from e

async with self._write_error_catcher():
await self._dsa.sendall(data_to_send)

while True:
async with self._police_officer.borrow(self._response):
async with self._read_error_catcher():
data, eot, _ = await self._dsa.recv_extended(None)

try:
self._protocol.receive_data(data)
except WebSocketProtocolError as e:
raise ProtocolError from e

async with self._police_officer.borrow(self._response):
for event in self._protocol.events():
if isinstance(event, TextMessage):
return event.data
Expand All @@ -165,40 +143,69 @@ async def next_payload(self) -> str | bytes | None:
await self.close()
return None
elif isinstance(event, Ping):
data_to_send = self._protocol.send(event.response())
try:
data_to_send: bytes = self._protocol.send(event.response())
except WebSocketProtocolError as e:
await self.close()
raise ProtocolError from e

async with self._write_error_catcher():
await self._dsa.sendall(data_to_send)
elif isinstance(event, Pong):
continue

while True:
async with self._read_error_catcher():
data, eot, _ = await self._dsa.recv_extended(None)

try:
self._protocol.receive_data(data)
except WebSocketProtocolError as e:
raise ProtocolError from e

for event in self._protocol.events():
if isinstance(event, TextMessage):
return event.data
elif isinstance(event, BytesMessage):
return event.data
elif isinstance(event, CloseConnection):
self._remote_shutdown = True
await self.close()
return None
elif isinstance(event, Ping):
data_to_send = self._protocol.send(event.response())
async with self._write_error_catcher():
await self._dsa.sendall(data_to_send)
elif isinstance(event, Pong):
continue

async def send_payload(self, buf: str | bytes) -> None:
"""Dispatch a buffer to remote."""
if self._dsa is None or self._response is None or self._police_officer is None:
raise OSError("The HTTP extension is closed or uninitialized")

try:
if isinstance(buf, str):
data_to_send: bytes = self._protocol.send(TextMessage(buf))
else:
data_to_send = self._protocol.send(BytesMessage(buf))
except WebSocketProtocolError as e:
raise ProtocolError from e

async with self._police_officer.borrow(self._response):
try:
if isinstance(buf, str):
data_to_send: bytes = self._protocol.send(TextMessage(buf))
else:
data_to_send = self._protocol.send(BytesMessage(buf))
except WebSocketProtocolError as e:
raise ProtocolError from e

async with self._write_error_catcher():
await self._dsa.sendall(data_to_send)

async def ping(self) -> None:
if self._dsa is None or self._response is None or self._police_officer is None:
raise OSError("The HTTP extension is closed or uninitialized")

try:
data_to_send: bytes = self._protocol.send(Ping())
except WebSocketProtocolError as e:
raise ProtocolError from e
async with self._police_officer.borrow(self._response):
try:
data_to_send: bytes = self._protocol.send(Ping())
except WebSocketProtocolError as e:
raise ProtocolError from e

async with self._write_error_catcher():
await self._dsa.sendall(data_to_send)
async with self._write_error_catcher():
await self._dsa.sendall(data_to_send)

@staticmethod
def supported_schemes() -> set[str]:
Expand Down
6 changes: 3 additions & 3 deletions src/urllib3/contrib/webextensions/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def send_payload(self, buf: str | bytes) -> None:
if self._police_officer is None or self._dsa is None:
raise OSError("The HTTP extension is closed or uninitialized")

with self._police_officer.borrow(self._response):
if isinstance(buf, str):
buf = buf.encode()
if isinstance(buf, str):
buf = buf.encode()

with self._police_officer.borrow(self._response):
with self._write_error_catcher():
self._dsa.sendall(buf)
130 changes: 69 additions & 61 deletions src/urllib3/contrib/webextensions/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,22 @@ def headers(self, http_version: HttpVersion) -> dict[str, str]:
def close(self) -> None:
"""End/Notify close for sub protocol."""
if self._dsa is not None:
if self._remote_shutdown is False:
try:
data_to_send: bytes = self._protocol.send(CloseConnection(0))
except WebSocketProtocolError as e:
raise ProtocolError from e

with self._write_error_catcher():
self._dsa.sendall(data_to_send)
self._dsa.close()
self._dsa = None
if self._police_officer is not None:
with self._police_officer.borrow(self._response):
if self._remote_shutdown is False:
try:
data_to_send: bytes = self._protocol.send(
CloseConnection(0)
)
except WebSocketProtocolError as e:
raise ProtocolError from e

with self._write_error_catcher():
self._dsa.sendall(data_to_send)
self._dsa.close()
self._dsa = None
else:
self._dsa = None
if self._response is not None:
if self._police_officer is not None:
self._police_officer.forget(self._response)
Expand All @@ -126,37 +132,8 @@ def next_payload(self) -> str | bytes | None:
if self._dsa is None or self._response is None or self._police_officer is None:
raise OSError("The HTTP extension is closed or uninitialized")

# we may have pending event to unpack!
for event in self._protocol.events():
if isinstance(event, TextMessage):
return event.data
elif isinstance(event, BytesMessage):
return event.data
elif isinstance(event, CloseConnection):
self._remote_shutdown = True
self.close()
return None
elif isinstance(event, Ping):
try:
data_to_send: bytes = self._protocol.send(event.response())
except WebSocketProtocolError as e:
self.close()
raise ProtocolError from e

with self._write_error_catcher():
self._dsa.sendall(data_to_send)

while True:
with self._police_officer.borrow(self._response):
with self._read_error_catcher():
data, eot, _ = self._dsa.recv_extended(None)

try:
self._protocol.receive_data(data)
except WebSocketProtocolError as e:
self.close()
raise ProtocolError from e

with self._police_officer.borrow(self._response):
# we may have pending event to unpack!
for event in self._protocol.events():
if isinstance(event, TextMessage):
return event.data
Expand All @@ -168,46 +145,77 @@ def next_payload(self) -> str | bytes | None:
return None
elif isinstance(event, Ping):
try:
data_to_send = self._protocol.send(event.response())
data_to_send: bytes = self._protocol.send(event.response())
except WebSocketProtocolError as e:
self.close()
raise ProtocolError from e

with self._write_error_catcher():
self._dsa.sendall(data_to_send)
elif isinstance(event, Pong):
continue

while True:
with self._read_error_catcher():
data, eot, _ = self._dsa.recv_extended(None)

try:
self._protocol.receive_data(data)
except WebSocketProtocolError as e:
self.close()
raise ProtocolError from e

for event in self._protocol.events():
if isinstance(event, TextMessage):
return event.data
elif isinstance(event, BytesMessage):
return event.data
elif isinstance(event, CloseConnection):
self._remote_shutdown = True
self.close()
return None
elif isinstance(event, Ping):
try:
data_to_send = self._protocol.send(event.response())
except WebSocketProtocolError as e:
self.close()
raise ProtocolError from e
with self._write_error_catcher():
self._dsa.sendall(data_to_send)
elif isinstance(event, Pong):
continue

def send_payload(self, buf: str | bytes) -> None:
"""Dispatch a buffer to remote."""
if self._dsa is None or self._response is None or self._police_officer is None:
raise OSError("The HTTP extension is closed or uninitialized")

try:
if isinstance(buf, str):
data_to_send: bytes = self._protocol.send(TextMessage(buf))
else:
data_to_send = self._protocol.send(BytesMessage(buf))
except WebSocketProtocolError as e:
self.close()
raise ProtocolError from e

with self._police_officer.borrow(self._response):
with self._write_error_catcher():
self._dsa.sendall(data_to_send)

def ping(self) -> None:
if self._dsa is None or self._response is None or self._police_officer is None:
raise OSError("The HTTP extension is closed or uninitialized")
if self._remote_shutdown is False:
try:
data_to_send: bytes = self._protocol.send(Ping())
if isinstance(buf, str):
data_to_send: bytes = self._protocol.send(TextMessage(buf))
else:
data_to_send = self._protocol.send(BytesMessage(buf))
except WebSocketProtocolError as e:
self.close()
raise ProtocolError from e

with self._write_error_catcher():
self._dsa.sendall(data_to_send)

def ping(self) -> None:
if self._dsa is None or self._response is None or self._police_officer is None:
raise OSError("The HTTP extension is closed or uninitialized")

with self._police_officer.borrow(self._response):
if self._remote_shutdown is False:
try:
data_to_send: bytes = self._protocol.send(Ping())
except WebSocketProtocolError as e:
self.close()
raise ProtocolError from e

with self._write_error_catcher():
self._dsa.sendall(data_to_send)

@staticmethod
def supported_schemes() -> set[str]:
return {"ws", "wss"}
Expand Down

0 comments on commit bd0d634

Please sign in to comment.