
fix mixture of sync/async sockets in IOPubThread #1275

Merged · 5 commits · Oct 26, 2024
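
In brief: the change keeps every socket owned by IOPubThread as a plain synchronous zmq.Socket and creates zmq.asyncio.Socket wrappers only inside the coroutines that actually await on them, instead of mixing sync and async sockets on the thread object. A minimal sketch of the two conversions the diff relies on (illustrative names, not the real ipykernel API; assumes a pyzmq recent enough to support shadow sockets and the socket_class argument):

import zmq
import zmq.asyncio

def shadow_as_sync(sock: zmq.Socket) -> zmq.Socket:
    # Shadow an existing (possibly zmq.asyncio) socket as a plain sync
    # zmq.Socket so it can safely be stored on the thread object.
    sync_sock = zmq.Socket(sock)
    if sync_sock.context is None:
        # shadow sockets don't always inherit the context attribute
        sync_sock.context = sock.context
    return sync_sock

async def consume(sync_sock: zmq.Socket) -> None:
    # Create the async wrapper only once we are inside the coroutine,
    # so it is tied to the event loop that is actually running.
    async_sock = zmq.asyncio.Socket(sync_sock)
    while True:
        msg = await async_sock.recv_multipart()
        print("received", msg)

In the diff, the shadowing step is applied to the IOPub socket passed into IOPubThread, and the async-wrapping step to the two PULL pipes (_pipe_in0 and _pipe_in1).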
4 changes: 3 additions & 1 deletion ipykernel/inprocess/ipkernel.py
@@ -6,6 +6,7 @@
import logging
import sys
from contextlib import contextmanager
from typing import cast

from anyio import TASK_STATUS_IGNORED
from anyio.abc import TaskStatus
@@ -146,7 +147,8 @@ def callback(msg):
assert frontend is not None
frontend.iopub_channel.call_handlers(msg)

self.iopub_thread.socket.on_recv = callback
iopub_socket = cast(DummySocket, self.iopub_thread.socket)
iopub_socket.on_recv = callback

# ------ Trait initializers -----------------------------------------------

3 changes: 3 additions & 0 deletions ipykernel/inprocess/socket.py
@@ -63,3 +63,6 @@ async def poll(self, timeout=0):
assert timeout == 0
statistics = self.in_receive_stream.statistics()
return statistics.current_buffer_used != 0

def close(self):
pass
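
The iostream.py changes below touch three socket paths: the per-thread event pipe, the pipe that collects output from forked subprocesses, and the IOPub socket itself. For the event pipe, each calling thread lazily creates its own synchronous PUSH socket (stored thread-locally) to wake the IO thread, whose PULL end is async-wrapped inside _handle_event. A rough sketch of that per-thread pipe (illustrative class, attribute names follow the hunks below):

import threading
import zmq

class EventPipeSketch:
    def __init__(self, context: zmq.Context, interface: str):
        self._context = context
        self._event_interface = interface  # e.g. an inproc:// address
        self._local = threading.local()

    @property
    def _event_pipe(self) -> zmq.Socket:
        # one sync PUSH socket per calling thread, created on first use
        try:
            return self._local.event_pipe
        except AttributeError:
            pipe = self._context.socket(zmq.PUSH, socket_class=zmq.Socket)
            pipe.linger = 0
            pipe.connect(self._event_interface)
            self._local.event_pipe = pipe
            return pipe

    def schedule(self, note: bytes = b"") -> None:
        # any thread may poke the IO thread; the receiving PULL socket is
        # wrapped in zmq.asyncio.Socket inside the coroutine that awaits it
        self._event_pipe.send(note)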
75 changes: 44 additions & 31 deletions ipykernel/iostream.py
@@ -3,6 +3,8 @@
# Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License.

from __future__ import annotations

import atexit
import contextvars
import io
@@ -15,7 +17,7 @@
from collections import defaultdict, deque
from io import StringIO, TextIOBase
from threading import Event, Thread, local
from typing import Any, Callable, Optional
from typing import Any, Callable

import zmq
from anyio import create_task_group, run, sleep, to_thread
@@ -25,8 +27,8 @@
# Globals
# -----------------------------------------------------------------------------

MASTER = 0
CHILD = 1
_PARENT = 0
_CHILD = 1

PIPE_BUFFER_SIZE = 1000

@@ -87,9 +89,16 @@ def __init__(self, socket, pipe=False):
Whether this process should listen for IOPub messages
piped from subprocesses.
"""
self.socket = socket
# ensure all of our sockets are sync zmq.Sockets
# don't create async wrappers until we are within the appropriate coroutines
self.socket: zmq.Socket[bytes] | None = zmq.Socket(socket)
if self.socket.context is None:
# bug in pyzmq, shadow socket doesn't always inherit context attribute
self.socket.context = socket.context # type:ignore[unreachable]
self._context = socket.context

self.background_socket = BackgroundSocket(self)
self._master_pid = os.getpid()
self._main_pid = os.getpid()
self._pipe_flag = pipe
if pipe:
self._setup_pipe_in()
@@ -106,8 +115,7 @@ def _setup_event_pipe(self):

def _setup_event_pipe(self):
"""Create the PULL socket listening for events that should fire in this thread."""
ctx = self.socket.context
self._pipe_in0 = ctx.socket(zmq.PULL)
self._pipe_in0 = self._context.socket(zmq.PULL, socket_class=zmq.Socket)
self._pipe_in0.linger = 0

_uuid = b2a_hex(os.urandom(16)).decode("ascii")
@@ -141,8 +149,8 @@ def _event_pipe(self):
event_pipe = self._local.event_pipe
except AttributeError:
# new thread, new event pipe
ctx = zmq.Context(self.socket.context)
event_pipe = ctx.socket(zmq.PUSH)
# create sync base socket
event_pipe = self._context.socket(zmq.PUSH, socket_class=zmq.Socket)
event_pipe.linger = 0
event_pipe.connect(self._event_interface)
self._local.event_pipe = event_pipe
@@ -161,9 +169,11 @@ async def _handle_event(self):
Whenever *an* event arrives on the event stream,
*all* waiting events are processed in order.
"""
# create async wrapper within coroutine
pipe_in = zmq.asyncio.Socket(self._pipe_in0)
try:
while True:
await self._pipe_in0.recv()
await pipe_in.recv()
# freeze event count so new writes don't extend the queue
# while we are processing
n_events = len(self._events)
@@ -177,12 +187,12 @@

def _setup_pipe_in(self):
"""setup listening pipe for IOPub from forked subprocesses"""
ctx = self.socket.context
ctx = self._context

# use UUID to authenticate pipe messages
self._pipe_uuid = os.urandom(16)

self._pipe_in1 = ctx.socket(zmq.PULL)
self._pipe_in1 = ctx.socket(zmq.PULL, socket_class=zmq.Socket)
self._pipe_in1.linger = 0

try:
@@ -199,6 +209,8 @@ def _setup_pipe_in(self):

async def _handle_pipe_msgs(self):
"""handle pipe messages from a subprocess"""
# create async wrapper within coroutine
self._async_pipe_in1 = zmq.asyncio.Socket(self._pipe_in1)
try:
while True:
await self._handle_pipe_msg()
@@ -209,8 +221,8 @@ async def _handle_pipe_msgs(self):

async def _handle_pipe_msg(self, msg=None):
"""handle a pipe message from a subprocess"""
msg = msg or await self._pipe_in1.recv_multipart()
if not self._pipe_flag or not self._is_master_process():
msg = msg or await self._async_pipe_in1.recv_multipart()
if not self._pipe_flag or not self._is_main_process():
return
if msg[0] != self._pipe_uuid:
print("Bad pipe message: %s", msg, file=sys.__stderr__)
@@ -225,14 +237,14 @@ def _setup_pipe_out(self):
pipe_out.connect("tcp://127.0.0.1:%i" % self._pipe_port)
return ctx, pipe_out

def _is_master_process(self):
return os.getpid() == self._master_pid
def _is_main_process(self):
return os.getpid() == self._main_pid

def _check_mp_mode(self):
"""check for forks, and switch to zmq pipeline if necessary"""
if not self._pipe_flag or self._is_master_process():
return MASTER
return CHILD
if not self._pipe_flag or self._is_main_process():
return _PARENT
return _CHILD

def start(self):
"""Start the IOPub thread"""
@@ -265,7 +277,8 @@ def close(self):
self._pipe_in0.close()
if self._pipe_flag:
self._pipe_in1.close()
self.socket.close()
if self.socket is not None:
self.socket.close()
self.socket = None

@property
@@ -301,12 +314,12 @@ def _really_send(self, msg, *args, **kwargs):
return

mp_mode = self._check_mp_mode()

if mp_mode != CHILD:
# we are master, do a regular send
if mp_mode != _CHILD:
# we are the main parent process, do a regular send
assert self.socket is not None
self.socket.send_multipart(msg, *args, **kwargs)
else:
# we are a child, pipe to master
# we are a child, pipe to parent process
# new context/socket for every pipe-out
# since forks don't teardown politely, use ctx.term to ensure send has completed
ctx, pipe_out = self._setup_pipe_out()
@@ -379,7 +392,7 @@ class OutStream(TextIOBase):
flush_interval = 0.2
topic = None
encoding = "UTF-8"
_exc: Optional[Any] = None
_exc: Any = None

def fileno(self):
"""
@@ -477,7 +490,7 @@ def __init__(
self._thread_to_parent = {}
self._thread_to_parent_header = {}
self._parent_header_global = {}
self._master_pid = os.getpid()
self._main_pid = os.getpid()
self._flush_pending = False
self._subprocess_flush_pending = False
self._buffer_lock = threading.RLock()
@@ -569,8 +582,8 @@ def _setup_stream_redirects(self, name):
self.watch_fd_thread.daemon = True
self.watch_fd_thread.start()

def _is_master_process(self):
return os.getpid() == self._master_pid
def _is_main_process(self):
return os.getpid() == self._main_pid

def set_parent(self, parent):
"""Set the parent header."""
@@ -674,7 +687,7 @@ def _flush(self):
ident=self.topic,
)

def write(self, string: str) -> Optional[int]: # type:ignore[override]
def write(self, string: str) -> int:
"""Write to current stream after encoding if necessary

Returns
@@ -700,15 +713,15 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
msg = "I/O operation on closed file"
raise ValueError(msg)

is_child = not self._is_master_process()
is_child = not self._is_main_process()
# only touch the buffer in the IO thread to avoid races
with self._buffer_lock:
self._buffers[frozenset(parent.items())].write(string)
if is_child:
# mp.Pool cannot be trusted to flush promptly (or ever),
# and this helps.
if self._subprocess_flush_pending:
return None
return 0
self._subprocess_flush_pending = True
# We can not rely on self._io_loop.call_later from a subprocess
self.pub_thread.schedule(self._flush)
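On the fork handling renamed here (_PARENT/_CHILD, _is_main_process): the send path stays fork-aware, so output printed from a forked child is piped back to the parent rather than published directly. A simplified, illustrative sketch of that branch (attribute names taken from the hunks above; not the exact implementation):

import os
import zmq

def really_send_sketch(thread, msg_parts):
    if os.getpid() == thread._main_pid:
        # parent process: regular send on the (sync) IOPub socket
        thread.socket.send_multipart(msg_parts)
    else:
        # forked child: open a short-lived PUSH socket back to the parent's
        # PULL pipe and prefix the message with the shared UUID so the
        # parent can authenticate it; ctx.term() ensures the send completes
        # even though forks don't tear down politely
        ctx = zmq.Context()
        pipe_out = ctx.socket(zmq.PUSH)
        pipe_out.linger = 3000  # illustrative value: give the send a moment to flush
        pipe_out.connect("tcp://127.0.0.1:%i" % thread._pipe_port)
        pipe_out.send_multipart([thread._pipe_uuid] + list(msg_parts))
        pipe_out.close()
        ctx.term()
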
4 changes: 2 additions & 2 deletions tests/test_io.py
@@ -15,7 +15,7 @@
import zmq.asyncio
from jupyter_client.session import Session

from ipykernel.iostream import MASTER, BackgroundSocket, IOPubThread, OutStream
from ipykernel.iostream import _PARENT, BackgroundSocket, IOPubThread, OutStream


@pytest.fixture()
@@ -73,7 +73,7 @@ async def test_io_thread(anyio_backend, iopub_thread):
ctx1, pipe = thread._setup_pipe_out()
pipe.close()
thread._pipe_in1.close()
thread._check_mp_mode = lambda: MASTER
thread._check_mp_mode = lambda: _PARENT
thread._really_send([b"hi"])
ctx1.destroy()
thread.stop()
24 changes: 12 additions & 12 deletions tests/test_kernel.py
@@ -32,10 +32,10 @@
)


def _check_master(kc, expected=True, stream="stdout"):
def _check_main(kc, expected=True, stream="stdout"):
execute(kc=kc, code="import sys")
flush_channels(kc)
msg_id, content = execute(kc=kc, code="print(sys.%s._is_master_process())" % stream)
msg_id, content = execute(kc=kc, code="print(sys.%s._is_main_process())" % stream)
stdout, stderr = assemble_output(kc.get_iopub_msg)
assert stdout.strip() == repr(expected)

@@ -56,7 +56,7 @@ def test_simple_print():
stdout, stderr = assemble_output(kc.get_iopub_msg)
assert stdout == "hi\n"
assert stderr == ""
_check_master(kc, expected=True)
_check_main(kc, expected=True)


def test_print_to_correct_cell_from_thread():
@@ -168,7 +168,7 @@ def test_capture_fd():
stdout, stderr = assemble_output(iopub)
assert stdout == "capsys\n"
assert stderr == ""
_check_master(kc, expected=True)
_check_main(kc, expected=True)


@pytest.mark.skip(reason="Currently don't capture during test as pytest does its own capturing")
@@ -182,7 +182,7 @@ def test_subprocess_peek_at_stream_fileno():
stdout, stderr = assemble_output(iopub)
assert stdout == "CAP1\nCAP2\n"
assert stderr == ""
_check_master(kc, expected=True)
_check_main(kc, expected=True)


def test_sys_path():
@@ -218,7 +218,7 @@ def test_sys_path_profile_dir():
def test_subprocess_print():
"""printing from forked mp.Process"""
with new_kernel() as kc:
_check_master(kc, expected=True)
_check_main(kc, expected=True)
flush_channels(kc)
np = 5
code = "\n".join(
@@ -238,8 +238,8 @@ def test_subprocess_print():
for n in range(np):
assert stdout.count(str(n)) == 1, stdout
assert stderr == ""
_check_master(kc, expected=True)
_check_master(kc, expected=True, stream="stderr")
_check_main(kc, expected=True)
_check_main(kc, expected=True, stream="stderr")


@flaky(max_runs=3)
@@ -261,8 +261,8 @@ def test_subprocess_noprint():
assert stdout == ""
assert stderr == ""

_check_master(kc, expected=True)
_check_master(kc, expected=True, stream="stderr")
_check_main(kc, expected=True)
_check_main(kc, expected=True, stream="stderr")


@flaky(max_runs=3)
@@ -287,8 +287,8 @@ def test_subprocess_error():
assert stdout == ""
assert "ValueError" in stderr

_check_master(kc, expected=True)
_check_master(kc, expected=True, stream="stderr")
_check_main(kc, expected=True)
_check_main(kc, expected=True, stream="stderr")


# raw_input tests