Skip to content

Commit b2bb553

Browse files
committed
Replace Tornado with AnyIO
1 parent 2a8adb9 commit b2bb553

27 files changed

+881
-871
lines changed

.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ jobs:
8585
run: |
8686
hatch run typing:test
8787
hatch run lint:style
88-
pipx run interrogate -vv .
88+
pipx run interrogate -vv . --fail-under 90
8989
pipx run doc8 --max-line-length=200
9090
9191
check_release:
+4-36
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
"""An in-process terminal example."""
22
import os
3-
import sys
43

5-
import tornado
4+
from anyio import run
65
from jupyter_console.ptshell import ZMQTerminalInteractiveShell
76

87
from ipykernel.inprocess.manager import InProcessKernelManager
@@ -13,46 +12,15 @@ def print_process_id():
1312
print("Process ID is:", os.getpid())
1413

1514

16-
def init_asyncio_patch():
17-
"""set default asyncio policy to be compatible with tornado
18-
Tornado 6 (at least) is not compatible with the default
19-
asyncio implementation on Windows
20-
Pick the older SelectorEventLoopPolicy on Windows
21-
if the known-incompatible default policy is in use.
22-
do this as early as possible to make it a low priority and overridable
23-
ref: https://github.com/tornadoweb/tornado/issues/2608
24-
FIXME: if/when tornado supports the defaults in asyncio,
25-
remove and bump tornado requirement for py38
26-
"""
27-
if (
28-
sys.platform.startswith("win")
29-
and sys.version_info >= (3, 8)
30-
and tornado.version_info < (6, 1)
31-
):
32-
import asyncio
33-
34-
try:
35-
from asyncio import WindowsProactorEventLoopPolicy, WindowsSelectorEventLoopPolicy
36-
except ImportError:
37-
pass
38-
# not affected
39-
else:
40-
if type(asyncio.get_event_loop_policy()) is WindowsProactorEventLoopPolicy:
41-
# WindowsProactorEventLoopPolicy is not compatible with tornado 6
42-
# fallback to the pre-3.8 default of Selector
43-
asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
44-
45-
46-
def main():
15+
async def main():
4716
"""The main function."""
4817
print_process_id()
4918

5019
# Create an in-process kernel
5120
# >>> print_process_id()
5221
# will print the same process ID as the main process
53-
init_asyncio_patch()
5422
kernel_manager = InProcessKernelManager()
55-
kernel_manager.start_kernel()
23+
await kernel_manager.start_kernel()
5624
kernel = kernel_manager.kernel
5725
kernel.gui = "qt4"
5826
kernel.shell.push({"foo": 43, "print_process_id": print_process_id})
@@ -64,4 +32,4 @@ def main():
6432

6533

6634
if __name__ == "__main__":
67-
main()
35+
run(main)

ipykernel/control.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""A thread for a control channel."""
2-
from threading import Thread
2+
from threading import Event, Thread
33

4-
from tornado.ioloop import IOLoop
4+
from anyio import create_task_group, run, to_thread
55

66
CONTROL_THREAD_NAME = "Control"
77

@@ -12,21 +12,29 @@ class ControlThread(Thread):
1212
def __init__(self, **kwargs):
1313
"""Initialize the thread."""
1414
Thread.__init__(self, name=CONTROL_THREAD_NAME, **kwargs)
15-
self.io_loop = IOLoop(make_current=False)
1615
self.pydev_do_not_trace = True
1716
self.is_pydev_daemon_thread = True
17+
self.__stop = Event()
18+
self._task = None
19+
20+
def set_task(self, task):
21+
self._task = task
1822

1923
def run(self):
2024
"""Run the thread."""
2125
self.name = CONTROL_THREAD_NAME
22-
try:
23-
self.io_loop.start()
24-
finally:
25-
self.io_loop.close()
26+
run(self._main)
27+
28+
async def _main(self):
29+
async with create_task_group() as tg:
30+
if self._task is not None:
31+
tg.start_soon(self._task)
32+
await to_thread.run_sync(self.__stop.wait)
33+
tg.cancel_scope.cancel()
2634

2735
def stop(self):
2836
"""Stop the thread.
2937
3038
This method is threadsafe.
3139
"""
32-
self.io_loop.add_callback(self.io_loop.stop)
40+
self.__stop.set()

ipykernel/debugger.py

+43-26
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
"""Debugger implementation for the IPython kernel."""
2+
from __future__ import annotations
3+
24
import os
35
import re
46
import sys
57
import typing as t
8+
from math import inf
9+
from typing import Any
610

711
import zmq
12+
from anyio import Event, create_memory_object_stream
813
from IPython.core.getipython import get_ipython
914
from IPython.core.inputtransformer2 import leading_empty_lines
10-
from tornado.locks import Event
11-
from tornado.queues import Queue
1215
from zmq.utils import jsonapi
1316

1417
try:
@@ -116,7 +119,9 @@ def __init__(self, event_callback, log):
116119
self.tcp_buffer = ""
117120
self._reset_tcp_pos()
118121
self.event_callback = event_callback
119-
self.message_queue: Queue[t.Any] = Queue()
122+
self.message_send_stream, self.message_receive_stream = create_memory_object_stream[
123+
dict[str, Any]
124+
](max_buffer_size=inf)
120125
self.log = log
121126

122127
def _reset_tcp_pos(self):
@@ -135,7 +140,7 @@ def _put_message(self, raw_msg):
135140
else:
136141
self.log.debug("QUEUE - put message:")
137142
self.log.debug(msg)
138-
self.message_queue.put_nowait(msg)
143+
self.message_send_stream.send_nowait(msg)
139144

140145
def put_tcp_frame(self, frame):
141146
"""Put a tcp frame in the queue."""
@@ -186,25 +191,31 @@ def put_tcp_frame(self, frame):
186191

187192
async def get_message(self):
188193
"""Get a message from the queue."""
189-
return await self.message_queue.get()
194+
return await self.message_receive_stream.receive()
190195

191196

192197
class DebugpyClient:
193198
"""A client for debugpy."""
194199

195-
def __init__(self, log, debugpy_stream, event_callback):
200+
def __init__(self, log, debugpy_socket, event_callback):
196201
"""Initialize the client."""
197202
self.log = log
198-
self.debugpy_stream = debugpy_stream
203+
self.debugpy_socket = debugpy_socket
199204
self.event_callback = event_callback
200205
self.message_queue = DebugpyMessageQueue(self._forward_event, self.log)
201206
self.debugpy_host = "127.0.0.1"
202207
self.debugpy_port = -1
203208
self.routing_id = None
204209
self.wait_for_attach = True
205-
self.init_event = Event()
210+
self._init_event = None
206211
self.init_event_seq = -1
207212

213+
@property
214+
def init_event(self):
215+
if self._init_event is None:
216+
self._init_event = Event()
217+
return self._init_event
218+
208219
def _get_endpoint(self):
209220
host, port = self.get_host_port()
210221
return "tcp://" + host + ":" + str(port)
@@ -215,9 +226,9 @@ def _forward_event(self, msg):
215226
self.init_event_seq = msg["seq"]
216227
self.event_callback(msg)
217228

218-
def _send_request(self, msg):
229+
async def _send_request(self, msg):
219230
if self.routing_id is None:
220-
self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID)
231+
self.routing_id = self.debugpy_socket.getsockopt(ROUTING_ID)
221232
content = jsonapi.dumps(
222233
msg,
223234
default=json_default,
@@ -232,7 +243,7 @@ def _send_request(self, msg):
232243
self.log.debug("DEBUGPYCLIENT:")
233244
self.log.debug(self.routing_id)
234245
self.log.debug(buf)
235-
self.debugpy_stream.send_multipart((self.routing_id, buf))
246+
await self.debugpy_socket.send_multipart((self.routing_id, buf))
236247

237248
async def _wait_for_response(self):
238249
# Since events are never pushed to the message_queue
@@ -250,7 +261,7 @@ async def _handle_init_sequence(self):
250261
"seq": int(self.init_event_seq) + 1,
251262
"command": "configurationDone",
252263
}
253-
self._send_request(configurationDone)
264+
await self._send_request(configurationDone)
254265

255266
# 3] Waits for configurationDone response
256267
await self._wait_for_response()
@@ -262,7 +273,7 @@ async def _handle_init_sequence(self):
262273
def get_host_port(self):
263274
"""Get the host debugpy port."""
264275
if self.debugpy_port == -1:
265-
socket = self.debugpy_stream.socket
276+
socket = self.debugpy_socket
266277
socket.bind_to_random_port("tcp://" + self.debugpy_host)
267278
self.endpoint = socket.getsockopt(zmq.LAST_ENDPOINT).decode("utf-8")
268279
socket.unbind(self.endpoint)
@@ -272,14 +283,13 @@ def get_host_port(self):
272283

273284
def connect_tcp_socket(self):
274285
"""Connect to the tcp socket."""
275-
self.debugpy_stream.socket.connect(self._get_endpoint())
276-
self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID)
286+
self.debugpy_socket.connect(self._get_endpoint())
287+
self.routing_id = self.debugpy_socket.getsockopt(ROUTING_ID)
277288

278289
def disconnect_tcp_socket(self):
279290
"""Disconnect from the tcp socket."""
280-
self.debugpy_stream.socket.disconnect(self._get_endpoint())
291+
self.debugpy_socket.disconnect(self._get_endpoint())
281292
self.routing_id = None
282-
self.init_event = Event()
283293
self.init_event_seq = -1
284294
self.wait_for_attach = True
285295

@@ -289,7 +299,7 @@ def receive_dap_frame(self, frame):
289299

290300
async def send_dap_request(self, msg):
291301
"""Send a dap request."""
292-
self._send_request(msg)
302+
await self._send_request(msg)
293303
if self.wait_for_attach and msg["command"] == "attach":
294304
rep = await self._handle_init_sequence()
295305
self.wait_for_attach = False
@@ -325,17 +335,19 @@ class Debugger:
325335
]
326336

327337
def __init__(
328-
self, log, debugpy_stream, event_callback, shell_socket, session, just_my_code=True
338+
self, log, debugpy_socket, event_callback, shell_socket, session, just_my_code=True
329339
):
330340
"""Initialize the debugger."""
331341
self.log = log
332-
self.debugpy_client = DebugpyClient(log, debugpy_stream, self._handle_event)
342+
self.debugpy_client = DebugpyClient(log, debugpy_socket, self._handle_event)
333343
self.shell_socket = shell_socket
334344
self.session = session
335345
self.is_started = False
336346
self.event_callback = event_callback
337347
self.just_my_code = just_my_code
338-
self.stopped_queue: Queue[t.Any] = Queue()
348+
self.stopped_send_stream, self.stopped_receive_stream = create_memory_object_stream[
349+
dict[str, Any]
350+
](max_buffer_size=inf)
339351

340352
self.started_debug_handlers = {}
341353
for msg_type in Debugger.started_debug_msg_types:
@@ -360,7 +372,7 @@ def __init__(
360372
def _handle_event(self, msg):
361373
if msg["event"] == "stopped":
362374
if msg["body"]["allThreadsStopped"]:
363-
self.stopped_queue.put_nowait(msg)
375+
self.stopped_send_stream.send_nowait(msg)
364376
# Do not forward the event now, will be done in the handle_stopped_event
365377
return
366378
else:
@@ -400,7 +412,7 @@ async def handle_stopped_event(self):
400412
"""Handle a stopped event."""
401413
# Wait for a stopped event message in the stopped queue
402414
# This message is used for triggering the 'threads' request
403-
event = await self.stopped_queue.get()
415+
event = await self.stopped_receive_stream.receive()
404416
req = {"seq": event["seq"] + 1, "type": "request", "command": "threads"}
405417
rep = await self._forward_message(req)
406418
for thread in rep["body"]["threads"]:
@@ -412,7 +424,7 @@ async def handle_stopped_event(self):
412424
def tcp_client(self):
413425
return self.debugpy_client
414426

415-
def start(self):
427+
async def start(self):
416428
"""Start the debugger."""
417429
if not self.debugpy_initialized:
418430
tmp_dir = get_tmp_directory()
@@ -430,7 +442,12 @@ def start(self):
430442
(self.shell_socket.getsockopt(ROUTING_ID)),
431443
)
432444

433-
ident, msg = self.session.recv(self.shell_socket, mode=0)
445+
msg = await self.shell_socket.recv_multipart()
446+
ident, msg = self.session.feed_identities(msg, copy=True)
447+
try:
448+
msg = self.session.deserialize(msg, content=True, copy=True)
449+
except Exception:
450+
self.log.error("Invalid message", exc_info=True)
434451
self.debugpy_initialized = msg["content"]["status"] == "ok"
435452

436453
# Don't remove leading empty lines when debugging so the breakpoints are correctly positioned
@@ -719,7 +736,7 @@ async def process_request(self, message):
719736
if self.is_started:
720737
self.log.info("The debugger has already started")
721738
else:
722-
self.is_started = self.start()
739+
self.is_started = await self.start()
723740
if self.is_started:
724741
self.log.info("The debugger has started")
725742
else:

ipykernel/eventloops.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -388,13 +388,12 @@ def loop_asyncio(kernel):
388388
loop._should_close = False # type:ignore[attr-defined]
389389

390390
# pause eventloop when there's an event on a zmq socket
391-
def process_stream_events(stream):
391+
def process_stream_events(socket):
392392
"""fall back to main loop when there's a socket event"""
393-
if stream.flush(limit=1):
394-
loop.stop()
393+
loop.stop()
395394

396-
notifier = partial(process_stream_events, kernel.shell_stream)
397-
loop.add_reader(kernel.shell_stream.getsockopt(zmq.FD), notifier)
395+
notifier = partial(process_stream_events, kernel.shell_socket)
396+
loop.add_reader(kernel.shell_socket.getsockopt(zmq.FD), notifier)
398397
loop.call_soon(notifier)
399398

400399
while True:

ipykernel/inprocess/blocking.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,10 @@ class BlockingInProcessKernelClient(InProcessKernelClient):
8080
iopub_channel_class = Type(BlockingInProcessChannel) # type:ignore[arg-type]
8181
stdin_channel_class = Type(BlockingInProcessStdInChannel) # type:ignore[arg-type]
8282

83-
def wait_for_ready(self):
83+
async def wait_for_ready(self):
8484
"""Wait for kernel info reply on shell channel."""
8585
while True:
86-
self.kernel_info()
86+
await self.kernel_info()
8787
try:
8888
msg = self.shell_channel.get_msg(block=True, timeout=1)
8989
except Empty:
@@ -103,6 +103,5 @@ def wait_for_ready(self):
103103
while True:
104104
try:
105105
msg = self.iopub_channel.get_msg(block=True, timeout=0.2)
106-
print(msg["msg_type"])
107106
except Empty:
108107
break

0 commit comments

Comments
 (0)