Skip to content

Commit

Permalink
Use zmq-anyio
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Nov 14, 2024
1 parent 8c8d7d2 commit 1fe492a
Show file tree
Hide file tree
Showing 15 changed files with 244 additions and 235 deletions.
4 changes: 2 additions & 2 deletions ipykernel/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ async def _send_request(self, msg):
self.log.debug("DEBUGPYCLIENT:")
self.log.debug(self.routing_id)
self.log.debug(buf)
await self.debugpy_socket.send_multipart((self.routing_id, buf))
await self.debugpy_socket.asend_multipart((self.routing_id, buf))

async def _wait_for_response(self):
# Since events are never pushed to the message_queue
Expand Down Expand Up @@ -437,7 +437,7 @@ async def start(self):
(self.shell_socket.getsockopt(ROUTING_ID)),
)

msg = await self.shell_socket.recv_multipart()
msg = await self.shell_socket.arecv_multipart()
ident, msg = self.session.feed_identities(msg, copy=True)
try:
msg = self.session.deserialize(msg, content=True, copy=True)
Expand Down
2 changes: 1 addition & 1 deletion ipykernel/inprocess/ipkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class InProcessKernel(IPythonKernel):
_underlying_iopub_socket = Instance(DummySocket, (False,))
iopub_thread: IOPubThread = Instance(IOPubThread) # type:ignore[assignment]

shell_socket = Instance(DummySocket, (True,)) # type:ignore[arg-type]
shell_socket = Instance(DummySocket, (True,))

@default("iopub_thread")
def _default_iopub_thread(self):
Expand Down
2 changes: 1 addition & 1 deletion ipykernel/inprocess/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

class Session(_Session):
async def recv(self, socket, copy=True):
return await socket.recv_multipart()
return await socket.arecv_multipart()

def send(
self,
Expand Down
72 changes: 34 additions & 38 deletions ipykernel/iostream.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Any, Callable

import zmq
import zmq_anyio
from anyio import create_task_group, run, sleep, to_thread
from jupyter_client.session import extract_header

Expand Down Expand Up @@ -55,11 +56,11 @@ def run(self):
run(self._main)

async def _main(self):
async with create_task_group() as tg:
async with create_task_group() as self._task_group:
for task in self._tasks:
tg.start_soon(task)
self._task_group.start_soon(task)
await to_thread.run_sync(self.__stop.wait)
tg.cancel_scope.cancel()
self._task_group.cancel_scope.cancel()

def stop(self):
"""Stop the thread.
Expand All @@ -78,7 +79,7 @@ class IOPubThread:
whose IO is always run in a thread.
"""

def __init__(self, socket, pipe=False):
def __init__(self, socket: zmq_anyio.Socket, pipe=False):
"""Create IOPub thread
Parameters
Expand All @@ -91,10 +92,7 @@ def __init__(self, socket, pipe=False):
"""
# ensure all of our sockets as sync zmq.Sockets
# don't create async wrappers until we are within the appropriate coroutines
self.socket: zmq.Socket[bytes] | None = zmq.Socket(socket)
if self.socket.context is None:
# bug in pyzmq, shadow socket doesn't always inherit context attribute
self.socket.context = socket.context # type:ignore[unreachable]
self.socket: zmq_anyio.Socket = socket
self._context = socket.context

self.background_socket = BackgroundSocket(self)
Expand All @@ -108,14 +106,14 @@ def __init__(self, socket, pipe=False):
self._event_pipe_gc_lock: threading.Lock = threading.Lock()
self._event_pipe_gc_seconds: float = 10
self._setup_event_pipe()
tasks = [self._handle_event, self._run_event_pipe_gc]
tasks = [self._handle_event, self._run_event_pipe_gc, self.socket.start]
if pipe:
tasks.append(self._handle_pipe_msgs)
self.thread = _IOPubThread(tasks)

def _setup_event_pipe(self):
"""Create the PULL socket listening for events that should fire in this thread."""
self._pipe_in0 = self._context.socket(zmq.PULL, socket_class=zmq.Socket)
self._pipe_in0 = self._context.socket(zmq.PULL)
self._pipe_in0.linger = 0

_uuid = b2a_hex(os.urandom(16)).decode("ascii")
Expand Down Expand Up @@ -150,7 +148,7 @@ def _event_pipe(self):
except AttributeError:
# new thread, new event pipe
# create sync base socket
event_pipe = self._context.socket(zmq.PUSH, socket_class=zmq.Socket)
event_pipe = self._context.socket(zmq.PUSH)
event_pipe.linger = 0
event_pipe.connect(self._event_interface)
self._local.event_pipe = event_pipe
Expand All @@ -169,30 +167,28 @@ async def _handle_event(self):
Whenever *an* event arrives on the event stream,
*all* waiting events are processed in order.
"""
# create async wrapper within coroutine
pipe_in = zmq.asyncio.Socket(self._pipe_in0)
try:
while True:
await pipe_in.recv()
# freeze event count so new writes don't extend the queue
# while we are processing
n_events = len(self._events)
for _ in range(n_events):
event_f = self._events.popleft()
event_f()
except Exception:
if self.thread.__stop.is_set():
return
raise
pipe_in = zmq_anyio.Socket(self._pipe_in0)
async with pipe_in:
try:
while True:
await pipe_in.arecv()
# freeze event count so new writes don't extend the queue
# while we are processing
n_events = len(self._events)
for _ in range(n_events):
event_f = self._events.popleft()
event_f()
except Exception:
if self.thread.__stop.is_set():
return
raise

def _setup_pipe_in(self):
"""setup listening pipe for IOPub from forked subprocesses"""
ctx = self._context

# use UUID to authenticate pipe messages
self._pipe_uuid = os.urandom(16)

self._pipe_in1 = ctx.socket(zmq.PULL, socket_class=zmq.Socket)
self._pipe_in1 = zmq_anyio.Socket(self._context.socket(zmq.PULL))
self._pipe_in1.linger = 0

try:
Expand All @@ -210,18 +206,18 @@ def _setup_pipe_in(self):
async def _handle_pipe_msgs(self):
"""handle pipe messages from a subprocess"""
# create async wrapper within coroutine
self._async_pipe_in1 = zmq.asyncio.Socket(self._pipe_in1)
try:
while True:
await self._handle_pipe_msg()
except Exception:
if self.thread.__stop.is_set():
return
raise
async with self._pipe_in1:
try:
while True:
await self._handle_pipe_msg()
except Exception:
if self.thread.__stop.is_set():
return
raise

async def _handle_pipe_msg(self, msg=None):
"""handle a pipe message from a subprocess"""
msg = msg or await self._async_pipe_in1.recv_multipart()
msg = msg or await self._pipe_in1.arecv_multipart()
if not self._pipe_flag or not self._is_main_process():
return
if msg[0] != self._pipe_uuid:
Expand Down
9 changes: 5 additions & 4 deletions ipykernel/ipkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from dataclasses import dataclass

import comm
import zmq.asyncio
import zmq_anyio
from anyio import TASK_STATUS_IGNORED, create_task_group, to_thread
from anyio.abc import TaskStatus
from IPython.core import release
Expand Down Expand Up @@ -76,7 +76,7 @@ class IPythonKernel(KernelBase):
help="Set this flag to False to deactivate the use of experimental IPython completion APIs.",
).tag(config=True)

debugpy_socket = Instance(zmq.asyncio.Socket, allow_none=True)
debugpy_socket = Instance(zmq_anyio.Socket, allow_none=True)

user_module = Any()

Expand Down Expand Up @@ -212,7 +212,8 @@ def __init__(self, **kwargs):
}

async def process_debugpy(self):
async with create_task_group() as tg:
assert self.debugpy_socket is not None
async with self.debug_shell_socket, self.debugpy_socket, create_task_group() as tg:
tg.start_soon(self.receive_debugpy_messages)
tg.start_soon(self.poll_stopped_queue)
await to_thread.run_sync(self.debugpy_stop.wait)
Expand All @@ -235,7 +236,7 @@ async def receive_debugpy_message(self, msg=None):

if msg is None:
assert self.debugpy_socket is not None
msg = await self.debugpy_socket.recv_multipart()
msg = await self.debugpy_socket.arecv_multipart()
# The first frame is the socket id, we can drop it
frame = msg[1].decode("utf-8")
self.log.debug("Debugpy received: %s", frame)
Expand Down
59 changes: 11 additions & 48 deletions ipykernel/kernelapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from pathlib import Path

import zmq
import zmq.asyncio
import zmq_anyio
from anyio import create_task_group, run
from IPython.core.application import ( # type:ignore[attr-defined]
BaseIPythonApplication,
Expand Down Expand Up @@ -325,15 +325,15 @@ def init_sockets(self):
"""Create a context, a session, and the kernel sockets."""
self.log.info("Starting the kernel at pid: %i", os.getpid())
assert self.context is None, "init_sockets cannot be called twice!"
self.context = context = zmq.asyncio.Context()
self.context = context = zmq.Context()
atexit.register(self.close)

self.shell_socket = context.socket(zmq.ROUTER)
self.shell_socket = zmq_anyio.Socket(context.socket(zmq.ROUTER))
self.shell_socket.linger = 1000
self.shell_port = self._bind_socket(self.shell_socket, self.shell_port)
self.log.debug("shell ROUTER Channel on port: %i" % self.shell_port)

self.stdin_socket = zmq.Context(context).socket(zmq.ROUTER)
self.stdin_socket = context.socket(zmq.ROUTER)
self.stdin_socket.linger = 1000
self.stdin_port = self._bind_socket(self.stdin_socket, self.stdin_port)
self.log.debug("stdin ROUTER Channel on port: %i" % self.stdin_port)
Expand All @@ -349,18 +349,19 @@ def init_sockets(self):

def init_control(self, context):
"""Initialize the control channel."""
self.control_socket = context.socket(zmq.ROUTER)
self.control_socket = zmq_anyio.Socket(context.socket(zmq.ROUTER))
self.control_socket.linger = 1000
self.control_port = self._bind_socket(self.control_socket, self.control_port)
self.log.debug("control ROUTER Channel on port: %i" % self.control_port)

self.debugpy_socket = context.socket(zmq.STREAM)
self.debugpy_socket = zmq_anyio.Socket(context, zmq.STREAM)
self.debugpy_socket.linger = 1000

self.debug_shell_socket = context.socket(zmq.DEALER)
self.debug_shell_socket = zmq_anyio.Socket(context.socket(zmq.DEALER))
self.debug_shell_socket.linger = 1000
if self.shell_socket.getsockopt(zmq.LAST_ENDPOINT):
self.debug_shell_socket.connect(self.shell_socket.getsockopt(zmq.LAST_ENDPOINT))
last_endpoint = self.shell_socket.getsockopt(zmq.LAST_ENDPOINT)
if last_endpoint:
self.debug_shell_socket.connect(last_endpoint)

if hasattr(zmq, "ROUTER_HANDOVER"):
# set router-handover to workaround zeromq reconnect problems
Expand All @@ -373,7 +374,7 @@ def init_control(self, context):

def init_iopub(self, context):
"""Initialize the iopub channel."""
self.iopub_socket = context.socket(zmq.PUB)
self.iopub_socket = zmq_anyio.Socket(context.socket(zmq.PUB))
self.iopub_socket.linger = 1000
self.iopub_port = self._bind_socket(self.iopub_socket, self.iopub_port)
self.log.debug("iopub PUB Channel on port: %i" % self.iopub_port)
Expand Down Expand Up @@ -634,43 +635,6 @@ def configure_tornado_logger(self):
handler.setFormatter(formatter)
logger.addHandler(handler)

def _init_asyncio_patch(self):
"""set default asyncio policy to be compatible with tornado
Tornado 6 (at least) is not compatible with the default
asyncio implementation on Windows
Pick the older SelectorEventLoopPolicy on Windows
if the known-incompatible default policy is in use.
Support for Proactor via a background thread is available in tornado 6.1,
but it is still preferable to run the Selector in the main thread
instead of the background.
do this as early as possible to make it a low priority and overridable
ref: https://github.com/tornadoweb/tornado/issues/2608
FIXME: if/when tornado supports the defaults in asyncio without threads,
remove and bump tornado requirement for py38.
Most likely, this will mean a new Python version
where asyncio.ProactorEventLoop supports add_reader and friends.
"""
if sys.platform.startswith("win"):
import asyncio

try:
from asyncio import WindowsProactorEventLoopPolicy, WindowsSelectorEventLoopPolicy
except ImportError:
pass
# not affected
else:
if type(asyncio.get_event_loop_policy()) is WindowsProactorEventLoopPolicy:
# WindowsProactorEventLoopPolicy is not compatible with tornado 6
# fallback to the pre-3.8 default of Selector
asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())

def init_pdb(self):
"""Replace pdb with IPython's version that is interruptible.
Expand All @@ -690,7 +654,6 @@ def init_pdb(self):
@catch_config_error
def initialize(self, argv=None):
"""Initialize the application."""
self._init_asyncio_patch()
super().initialize(argv)
if self.subapp is not None:
return
Expand Down
Loading

0 comments on commit 1fe492a

Please sign in to comment.