Skip to content

Commit

Permalink
fix mixture of sync/async sockets in IOPubThread (#1275)
Browse files Browse the repository at this point in the history
  • Loading branch information
minrk authored Oct 26, 2024
1 parent 8cc1ee3 commit bf10447
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 46 deletions.
4 changes: 3 additions & 1 deletion ipykernel/inprocess/ipkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import sys
from contextlib import contextmanager
from typing import cast

from anyio import TASK_STATUS_IGNORED
from anyio.abc import TaskStatus
Expand Down Expand Up @@ -146,7 +147,8 @@ def callback(msg):
assert frontend is not None
frontend.iopub_channel.call_handlers(msg)

self.iopub_thread.socket.on_recv = callback
iopub_socket = cast(DummySocket, self.iopub_thread.socket)
iopub_socket.on_recv = callback

# ------ Trait initializers -----------------------------------------------

Expand Down
3 changes: 3 additions & 0 deletions ipykernel/inprocess/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,6 @@ async def poll(self, timeout=0):
assert timeout == 0
statistics = self.in_receive_stream.statistics()
return statistics.current_buffer_used != 0

def close(self):
pass
75 changes: 44 additions & 31 deletions ipykernel/iostream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License.

from __future__ import annotations

import atexit
import contextvars
import io
Expand All @@ -15,7 +17,7 @@
from collections import defaultdict, deque
from io import StringIO, TextIOBase
from threading import Event, Thread, local
from typing import Any, Callable, Optional
from typing import Any, Callable

import zmq
from anyio import create_task_group, run, sleep, to_thread
Expand All @@ -25,8 +27,8 @@
# Globals
# -----------------------------------------------------------------------------

MASTER = 0
CHILD = 1
_PARENT = 0
_CHILD = 1

PIPE_BUFFER_SIZE = 1000

Expand Down Expand Up @@ -87,9 +89,16 @@ def __init__(self, socket, pipe=False):
Whether this process should listen for IOPub messages
piped from subprocesses.
"""
self.socket = socket
# ensure all of our sockets as sync zmq.Sockets
# don't create async wrappers until we are within the appropriate coroutines
self.socket: zmq.Socket[bytes] | None = zmq.Socket(socket)
if self.socket.context is None:
# bug in pyzmq, shadow socket doesn't always inherit context attribute
self.socket.context = socket.context # type:ignore[unreachable]
self._context = socket.context

self.background_socket = BackgroundSocket(self)
self._master_pid = os.getpid()
self._main_pid = os.getpid()
self._pipe_flag = pipe
if pipe:
self._setup_pipe_in()
Expand All @@ -106,8 +115,7 @@ def __init__(self, socket, pipe=False):

def _setup_event_pipe(self):
"""Create the PULL socket listening for events that should fire in this thread."""
ctx = self.socket.context
self._pipe_in0 = ctx.socket(zmq.PULL)
self._pipe_in0 = self._context.socket(zmq.PULL, socket_class=zmq.Socket)
self._pipe_in0.linger = 0

_uuid = b2a_hex(os.urandom(16)).decode("ascii")
Expand Down Expand Up @@ -141,8 +149,8 @@ def _event_pipe(self):
event_pipe = self._local.event_pipe
except AttributeError:
# new thread, new event pipe
ctx = zmq.Context(self.socket.context)
event_pipe = ctx.socket(zmq.PUSH)
# create sync base socket
event_pipe = self._context.socket(zmq.PUSH, socket_class=zmq.Socket)
event_pipe.linger = 0
event_pipe.connect(self._event_interface)
self._local.event_pipe = event_pipe
Expand All @@ -161,9 +169,11 @@ async def _handle_event(self):
Whenever *an* event arrives on the event stream,
*all* waiting events are processed in order.
"""
# create async wrapper within coroutine
pipe_in = zmq.asyncio.Socket(self._pipe_in0)
try:
while True:
await self._pipe_in0.recv()
await pipe_in.recv()
# freeze event count so new writes don't extend the queue
# while we are processing
n_events = len(self._events)
Expand All @@ -177,12 +187,12 @@ async def _handle_event(self):

def _setup_pipe_in(self):
"""setup listening pipe for IOPub from forked subprocesses"""
ctx = self.socket.context
ctx = self._context

# use UUID to authenticate pipe messages
self._pipe_uuid = os.urandom(16)

self._pipe_in1 = ctx.socket(zmq.PULL)
self._pipe_in1 = ctx.socket(zmq.PULL, socket_class=zmq.Socket)
self._pipe_in1.linger = 0

try:
Expand All @@ -199,6 +209,8 @@ def _setup_pipe_in(self):

async def _handle_pipe_msgs(self):
"""handle pipe messages from a subprocess"""
# create async wrapper within coroutine
self._async_pipe_in1 = zmq.asyncio.Socket(self._pipe_in1)
try:
while True:
await self._handle_pipe_msg()
Expand All @@ -209,8 +221,8 @@ async def _handle_pipe_msgs(self):

async def _handle_pipe_msg(self, msg=None):
"""handle a pipe message from a subprocess"""
msg = msg or await self._pipe_in1.recv_multipart()
if not self._pipe_flag or not self._is_master_process():
msg = msg or await self._async_pipe_in1.recv_multipart()
if not self._pipe_flag or not self._is_main_process():
return
if msg[0] != self._pipe_uuid:
print("Bad pipe message: %s", msg, file=sys.__stderr__)
Expand All @@ -225,14 +237,14 @@ def _setup_pipe_out(self):
pipe_out.connect("tcp://127.0.0.1:%i" % self._pipe_port)
return ctx, pipe_out

def _is_master_process(self):
return os.getpid() == self._master_pid
def _is_main_process(self):
return os.getpid() == self._main_pid

def _check_mp_mode(self):
"""check for forks, and switch to zmq pipeline if necessary"""
if not self._pipe_flag or self._is_master_process():
return MASTER
return CHILD
if not self._pipe_flag or self._is_main_process():
return _PARENT
return _CHILD

def start(self):
"""Start the IOPub thread"""
Expand Down Expand Up @@ -265,7 +277,8 @@ def close(self):
self._pipe_in0.close()
if self._pipe_flag:
self._pipe_in1.close()
self.socket.close()
if self.socket is not None:
self.socket.close()
self.socket = None

@property
Expand Down Expand Up @@ -301,12 +314,12 @@ def _really_send(self, msg, *args, **kwargs):
return

mp_mode = self._check_mp_mode()

if mp_mode != CHILD:
# we are master, do a regular send
if mp_mode != _CHILD:
# we are the main parent process, do a regular send
assert self.socket is not None
self.socket.send_multipart(msg, *args, **kwargs)
else:
# we are a child, pipe to master
# we are a child, pipe to parent process
# new context/socket for every pipe-out
# since forks don't teardown politely, use ctx.term to ensure send has completed
ctx, pipe_out = self._setup_pipe_out()
Expand Down Expand Up @@ -379,7 +392,7 @@ class OutStream(TextIOBase):
flush_interval = 0.2
topic = None
encoding = "UTF-8"
_exc: Optional[Any] = None
_exc: Any = None

def fileno(self):
"""
Expand Down Expand Up @@ -477,7 +490,7 @@ def __init__(
self._thread_to_parent = {}
self._thread_to_parent_header = {}
self._parent_header_global = {}
self._master_pid = os.getpid()
self._main_pid = os.getpid()
self._flush_pending = False
self._subprocess_flush_pending = False
self._buffer_lock = threading.RLock()
Expand Down Expand Up @@ -569,8 +582,8 @@ def _setup_stream_redirects(self, name):
self.watch_fd_thread.daemon = True
self.watch_fd_thread.start()

def _is_master_process(self):
return os.getpid() == self._master_pid
def _is_main_process(self):
return os.getpid() == self._main_pid

def set_parent(self, parent):
"""Set the parent header."""
Expand Down Expand Up @@ -674,7 +687,7 @@ def _flush(self):
ident=self.topic,
)

def write(self, string: str) -> Optional[int]: # type:ignore[override]
def write(self, string: str) -> int:
"""Write to current stream after encoding if necessary
Returns
Expand All @@ -700,15 +713,15 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
msg = "I/O operation on closed file"
raise ValueError(msg)

is_child = not self._is_master_process()
is_child = not self._is_main_process()
# only touch the buffer in the IO thread to avoid races
with self._buffer_lock:
self._buffers[frozenset(parent.items())].write(string)
if is_child:
# mp.Pool cannot be trusted to flush promptly (or ever),
# and this helps.
if self._subprocess_flush_pending:
return None
return 0
self._subprocess_flush_pending = True
# We can not rely on self._io_loop.call_later from a subprocess
self.pub_thread.schedule(self._flush)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import zmq.asyncio
from jupyter_client.session import Session

from ipykernel.iostream import MASTER, BackgroundSocket, IOPubThread, OutStream
from ipykernel.iostream import _PARENT, BackgroundSocket, IOPubThread, OutStream


@pytest.fixture()
Expand Down Expand Up @@ -73,7 +73,7 @@ async def test_io_thread(anyio_backend, iopub_thread):
ctx1, pipe = thread._setup_pipe_out()
pipe.close()
thread._pipe_in1.close()
thread._check_mp_mode = lambda: MASTER
thread._check_mp_mode = lambda: _PARENT
thread._really_send([b"hi"])
ctx1.destroy()
thread.stop()
Expand Down
24 changes: 12 additions & 12 deletions tests/test_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
)


def _check_master(kc, expected=True, stream="stdout"):
def _check_main(kc, expected=True, stream="stdout"):
execute(kc=kc, code="import sys")
flush_channels(kc)
msg_id, content = execute(kc=kc, code="print(sys.%s._is_master_process())" % stream)
msg_id, content = execute(kc=kc, code="print(sys.%s._is_main_process())" % stream)
stdout, stderr = assemble_output(kc.get_iopub_msg)
assert stdout.strip() == repr(expected)

Expand All @@ -56,7 +56,7 @@ def test_simple_print():
stdout, stderr = assemble_output(kc.get_iopub_msg)
assert stdout == "hi\n"
assert stderr == ""
_check_master(kc, expected=True)
_check_main(kc, expected=True)


def test_print_to_correct_cell_from_thread():
Expand Down Expand Up @@ -168,7 +168,7 @@ def test_capture_fd():
stdout, stderr = assemble_output(iopub)
assert stdout == "capsys\n"
assert stderr == ""
_check_master(kc, expected=True)
_check_main(kc, expected=True)


@pytest.mark.skip(reason="Currently don't capture during test as pytest does its own capturing")
Expand All @@ -182,7 +182,7 @@ def test_subprocess_peek_at_stream_fileno():
stdout, stderr = assemble_output(iopub)
assert stdout == "CAP1\nCAP2\n"
assert stderr == ""
_check_master(kc, expected=True)
_check_main(kc, expected=True)


def test_sys_path():
Expand Down Expand Up @@ -218,7 +218,7 @@ def test_sys_path_profile_dir():
def test_subprocess_print():
"""printing from forked mp.Process"""
with new_kernel() as kc:
_check_master(kc, expected=True)
_check_main(kc, expected=True)
flush_channels(kc)
np = 5
code = "\n".join(
Expand All @@ -238,8 +238,8 @@ def test_subprocess_print():
for n in range(np):
assert stdout.count(str(n)) == 1, stdout
assert stderr == ""
_check_master(kc, expected=True)
_check_master(kc, expected=True, stream="stderr")
_check_main(kc, expected=True)
_check_main(kc, expected=True, stream="stderr")


@flaky(max_runs=3)
Expand All @@ -261,8 +261,8 @@ def test_subprocess_noprint():
assert stdout == ""
assert stderr == ""

_check_master(kc, expected=True)
_check_master(kc, expected=True, stream="stderr")
_check_main(kc, expected=True)
_check_main(kc, expected=True, stream="stderr")


@flaky(max_runs=3)
Expand All @@ -287,8 +287,8 @@ def test_subprocess_error():
assert stdout == ""
assert "ValueError" in stderr

_check_master(kc, expected=True)
_check_master(kc, expected=True, stream="stderr")
_check_main(kc, expected=True)
_check_main(kc, expected=True, stream="stderr")


# raw_input tests
Expand Down

0 comments on commit bf10447

Please sign in to comment.