Skip to content

Commit dc2674d

Browse files
committed
Replace Tornado with AnyIO
1 parent 2a8adb9 commit dc2674d

27 files changed

+879
-871
lines changed

.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ jobs:
8585
run: |
8686
hatch run typing:test
8787
hatch run lint:style
88-
pipx run interrogate -vv .
88+
pipx run interrogate -vv . --fail-under 90
8989
pipx run doc8 --max-line-length=200
9090
9191
check_release:
+4-36
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
"""An in-process terminal example."""
22
import os
3-
import sys
43

5-
import tornado
4+
from anyio import run
65
from jupyter_console.ptshell import ZMQTerminalInteractiveShell
76

87
from ipykernel.inprocess.manager import InProcessKernelManager
@@ -13,46 +12,15 @@ def print_process_id():
1312
print("Process ID is:", os.getpid())
1413

1514

16-
def init_asyncio_patch():
17-
"""set default asyncio policy to be compatible with tornado
18-
Tornado 6 (at least) is not compatible with the default
19-
asyncio implementation on Windows
20-
Pick the older SelectorEventLoopPolicy on Windows
21-
if the known-incompatible default policy is in use.
22-
do this as early as possible to make it a low priority and overridable
23-
ref: https://github.com/tornadoweb/tornado/issues/2608
24-
FIXME: if/when tornado supports the defaults in asyncio,
25-
remove and bump tornado requirement for py38
26-
"""
27-
if (
28-
sys.platform.startswith("win")
29-
and sys.version_info >= (3, 8)
30-
and tornado.version_info < (6, 1)
31-
):
32-
import asyncio
33-
34-
try:
35-
from asyncio import WindowsProactorEventLoopPolicy, WindowsSelectorEventLoopPolicy
36-
except ImportError:
37-
pass
38-
# not affected
39-
else:
40-
if type(asyncio.get_event_loop_policy()) is WindowsProactorEventLoopPolicy:
41-
# WindowsProactorEventLoopPolicy is not compatible with tornado 6
42-
# fallback to the pre-3.8 default of Selector
43-
asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
44-
45-
46-
def main():
15+
async def main():
4716
"""The main function."""
4817
print_process_id()
4918

5019
# Create an in-process kernel
5120
# >>> print_process_id()
5221
# will print the same process ID as the main process
53-
init_asyncio_patch()
5422
kernel_manager = InProcessKernelManager()
55-
kernel_manager.start_kernel()
23+
await kernel_manager.start_kernel()
5624
kernel = kernel_manager.kernel
5725
kernel.gui = "qt4"
5826
kernel.shell.push({"foo": 43, "print_process_id": print_process_id})
@@ -64,4 +32,4 @@ def main():
6432

6533

6634
if __name__ == "__main__":
67-
main()
35+
run(main)

ipykernel/control.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""A thread for a control channel."""
2-
from threading import Thread
2+
from threading import Event, Thread
33

4-
from tornado.ioloop import IOLoop
4+
from anyio import create_task_group, run, to_thread
55

66
CONTROL_THREAD_NAME = "Control"
77

@@ -12,21 +12,29 @@ class ControlThread(Thread):
1212
def __init__(self, **kwargs):
1313
"""Initialize the thread."""
1414
Thread.__init__(self, name=CONTROL_THREAD_NAME, **kwargs)
15-
self.io_loop = IOLoop(make_current=False)
1615
self.pydev_do_not_trace = True
1716
self.is_pydev_daemon_thread = True
17+
self.__stop = Event()
18+
self._task = None
19+
20+
def set_task(self, task):
21+
self._task = task
1822

1923
def run(self):
2024
"""Run the thread."""
2125
self.name = CONTROL_THREAD_NAME
22-
try:
23-
self.io_loop.start()
24-
finally:
25-
self.io_loop.close()
26+
run(self._main)
27+
28+
async def _main(self):
29+
async with create_task_group() as tg:
30+
if self._task is not None:
31+
tg.start_soon(self._task)
32+
await to_thread.run_sync(self.__stop.wait)
33+
tg.cancel_scope.cancel()
2634

2735
def stop(self):
2836
"""Stop the thread.
2937
3038
This method is threadsafe.
3139
"""
32-
self.io_loop.add_callback(self.io_loop.stop)
40+
self.__stop.set()

ipykernel/debugger.py

+41-26
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
import re
44
import sys
55
import typing as t
6+
from math import inf
7+
from typing import Any
68

79
import zmq
10+
from anyio import Event, create_memory_object_stream
811
from IPython.core.getipython import get_ipython
912
from IPython.core.inputtransformer2 import leading_empty_lines
10-
from tornado.locks import Event
11-
from tornado.queues import Queue
1213
from zmq.utils import jsonapi
1314

1415
try:
@@ -116,7 +117,9 @@ def __init__(self, event_callback, log):
116117
self.tcp_buffer = ""
117118
self._reset_tcp_pos()
118119
self.event_callback = event_callback
119-
self.message_queue: Queue[t.Any] = Queue()
120+
self.message_send_stream, self.message_receive_stream = create_memory_object_stream[
121+
dict[str, Any]
122+
](max_buffer_size=inf)
120123
self.log = log
121124

122125
def _reset_tcp_pos(self):
@@ -135,7 +138,7 @@ def _put_message(self, raw_msg):
135138
else:
136139
self.log.debug("QUEUE - put message:")
137140
self.log.debug(msg)
138-
self.message_queue.put_nowait(msg)
141+
self.message_send_stream.send_nowait(msg)
139142

140143
def put_tcp_frame(self, frame):
141144
"""Put a tcp frame in the queue."""
@@ -186,25 +189,31 @@ def put_tcp_frame(self, frame):
186189

187190
async def get_message(self):
188191
"""Get a message from the queue."""
189-
return await self.message_queue.get()
192+
return await self.message_receive_stream.receive()
190193

191194

192195
class DebugpyClient:
193196
"""A client for debugpy."""
194197

195-
def __init__(self, log, debugpy_stream, event_callback):
198+
def __init__(self, log, debugpy_socket, event_callback):
196199
"""Initialize the client."""
197200
self.log = log
198-
self.debugpy_stream = debugpy_stream
201+
self.debugpy_socket = debugpy_socket
199202
self.event_callback = event_callback
200203
self.message_queue = DebugpyMessageQueue(self._forward_event, self.log)
201204
self.debugpy_host = "127.0.0.1"
202205
self.debugpy_port = -1
203206
self.routing_id = None
204207
self.wait_for_attach = True
205-
self.init_event = Event()
208+
self._init_event = None
206209
self.init_event_seq = -1
207210

211+
@property
212+
def init_event(self):
213+
if self._init_event is None:
214+
self._init_event = Event()
215+
return self._init_event
216+
208217
def _get_endpoint(self):
209218
host, port = self.get_host_port()
210219
return "tcp://" + host + ":" + str(port)
@@ -215,9 +224,9 @@ def _forward_event(self, msg):
215224
self.init_event_seq = msg["seq"]
216225
self.event_callback(msg)
217226

218-
def _send_request(self, msg):
227+
async def _send_request(self, msg):
219228
if self.routing_id is None:
220-
self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID)
229+
self.routing_id = self.debugpy_socket.getsockopt(ROUTING_ID)
221230
content = jsonapi.dumps(
222231
msg,
223232
default=json_default,
@@ -232,7 +241,7 @@ def _send_request(self, msg):
232241
self.log.debug("DEBUGPYCLIENT:")
233242
self.log.debug(self.routing_id)
234243
self.log.debug(buf)
235-
self.debugpy_stream.send_multipart((self.routing_id, buf))
244+
await self.debugpy_socket.send_multipart((self.routing_id, buf))
236245

237246
async def _wait_for_response(self):
238247
# Since events are never pushed to the message_queue
@@ -250,7 +259,7 @@ async def _handle_init_sequence(self):
250259
"seq": int(self.init_event_seq) + 1,
251260
"command": "configurationDone",
252261
}
253-
self._send_request(configurationDone)
262+
await self._send_request(configurationDone)
254263

255264
# 3] Waits for configurationDone response
256265
await self._wait_for_response()
@@ -262,7 +271,7 @@ async def _handle_init_sequence(self):
262271
def get_host_port(self):
263272
"""Get the host debugpy port."""
264273
if self.debugpy_port == -1:
265-
socket = self.debugpy_stream.socket
274+
socket = self.debugpy_socket
266275
socket.bind_to_random_port("tcp://" + self.debugpy_host)
267276
self.endpoint = socket.getsockopt(zmq.LAST_ENDPOINT).decode("utf-8")
268277
socket.unbind(self.endpoint)
@@ -272,14 +281,13 @@ def get_host_port(self):
272281

273282
def connect_tcp_socket(self):
274283
"""Connect to the tcp socket."""
275-
self.debugpy_stream.socket.connect(self._get_endpoint())
276-
self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID)
284+
self.debugpy_socket.connect(self._get_endpoint())
285+
self.routing_id = self.debugpy_socket.getsockopt(ROUTING_ID)
277286

278287
def disconnect_tcp_socket(self):
279288
"""Disconnect from the tcp socket."""
280-
self.debugpy_stream.socket.disconnect(self._get_endpoint())
289+
self.debugpy_socket.disconnect(self._get_endpoint())
281290
self.routing_id = None
282-
self.init_event = Event()
283291
self.init_event_seq = -1
284292
self.wait_for_attach = True
285293

@@ -289,7 +297,7 @@ def receive_dap_frame(self, frame):
289297

290298
async def send_dap_request(self, msg):
291299
"""Send a dap request."""
292-
self._send_request(msg)
300+
await self._send_request(msg)
293301
if self.wait_for_attach and msg["command"] == "attach":
294302
rep = await self._handle_init_sequence()
295303
self.wait_for_attach = False
@@ -325,17 +333,19 @@ class Debugger:
325333
]
326334

327335
def __init__(
328-
self, log, debugpy_stream, event_callback, shell_socket, session, just_my_code=True
336+
self, log, debugpy_socket, event_callback, shell_socket, session, just_my_code=True
329337
):
330338
"""Initialize the debugger."""
331339
self.log = log
332-
self.debugpy_client = DebugpyClient(log, debugpy_stream, self._handle_event)
340+
self.debugpy_client = DebugpyClient(log, debugpy_socket, self._handle_event)
333341
self.shell_socket = shell_socket
334342
self.session = session
335343
self.is_started = False
336344
self.event_callback = event_callback
337345
self.just_my_code = just_my_code
338-
self.stopped_queue: Queue[t.Any] = Queue()
346+
self.stopped_send_stream, self.stopped_receive_stream = create_memory_object_stream[
347+
dict[str, Any]
348+
](max_buffer_size=inf)
339349

340350
self.started_debug_handlers = {}
341351
for msg_type in Debugger.started_debug_msg_types:
@@ -360,7 +370,7 @@ def __init__(
360370
def _handle_event(self, msg):
361371
if msg["event"] == "stopped":
362372
if msg["body"]["allThreadsStopped"]:
363-
self.stopped_queue.put_nowait(msg)
373+
self.stopped_send_stream.send_nowait(msg)
364374
# Do not forward the event now, will be done in the handle_stopped_event
365375
return
366376
else:
@@ -400,7 +410,7 @@ async def handle_stopped_event(self):
400410
"""Handle a stopped event."""
401411
# Wait for a stopped event message in the stopped queue
402412
# This message is used for triggering the 'threads' request
403-
event = await self.stopped_queue.get()
413+
event = await self.stopped_receive_stream.receive()
404414
req = {"seq": event["seq"] + 1, "type": "request", "command": "threads"}
405415
rep = await self._forward_message(req)
406416
for thread in rep["body"]["threads"]:
@@ -412,7 +422,7 @@ async def handle_stopped_event(self):
412422
def tcp_client(self):
413423
return self.debugpy_client
414424

415-
def start(self):
425+
async def start(self):
416426
"""Start the debugger."""
417427
if not self.debugpy_initialized:
418428
tmp_dir = get_tmp_directory()
@@ -430,7 +440,12 @@ def start(self):
430440
(self.shell_socket.getsockopt(ROUTING_ID)),
431441
)
432442

433-
ident, msg = self.session.recv(self.shell_socket, mode=0)
443+
msg = await self.shell_socket.recv_multipart()
444+
ident, msg = self.session.feed_identities(msg, copy=True)
445+
try:
446+
msg = self.session.deserialize(msg, content=True, copy=True)
447+
except Exception:
448+
self.log.error("Invalid message", exc_info=True)
434449
self.debugpy_initialized = msg["content"]["status"] == "ok"
435450

436451
# Don't remove leading empty lines when debugging so the breakpoints are correctly positioned
@@ -719,7 +734,7 @@ async def process_request(self, message):
719734
if self.is_started:
720735
self.log.info("The debugger has already started")
721736
else:
722-
self.is_started = self.start()
737+
self.is_started = await self.start()
723738
if self.is_started:
724739
self.log.info("The debugger has started")
725740
else:

ipykernel/eventloops.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -388,13 +388,12 @@ def loop_asyncio(kernel):
388388
loop._should_close = False # type:ignore[attr-defined]
389389

390390
# pause eventloop when there's an event on a zmq socket
391-
def process_stream_events(stream):
391+
def process_stream_events(socket):
392392
"""fall back to main loop when there's a socket event"""
393-
if stream.flush(limit=1):
394-
loop.stop()
393+
loop.stop()
395394

396-
notifier = partial(process_stream_events, kernel.shell_stream)
397-
loop.add_reader(kernel.shell_stream.getsockopt(zmq.FD), notifier)
395+
notifier = partial(process_stream_events, kernel.shell_socket)
396+
loop.add_reader(kernel.shell_socket.getsockopt(zmq.FD), notifier)
398397
loop.call_soon(notifier)
399398

400399
while True:

ipykernel/inprocess/blocking.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,10 @@ class BlockingInProcessKernelClient(InProcessKernelClient):
8080
iopub_channel_class = Type(BlockingInProcessChannel) # type:ignore[arg-type]
8181
stdin_channel_class = Type(BlockingInProcessStdInChannel) # type:ignore[arg-type]
8282

83-
def wait_for_ready(self):
83+
async def wait_for_ready(self):
8484
"""Wait for kernel info reply on shell channel."""
8585
while True:
86-
self.kernel_info()
86+
await self.kernel_info()
8787
try:
8888
msg = self.shell_channel.get_msg(block=True, timeout=1)
8989
except Empty:
@@ -103,6 +103,5 @@ def wait_for_ready(self):
103103
while True:
104104
try:
105105
msg = self.iopub_channel.get_msg(block=True, timeout=0.2)
106-
print(msg["msg_type"])
107106
except Empty:
108107
break

0 commit comments

Comments
 (0)