From 9ae69406bfe19694e54ac23cddb9da22394c6ecc Mon Sep 17 00:00:00 2001 From: Min RK Date: Mon, 21 Oct 2024 14:53:16 +0200 Subject: [PATCH 1/4] fix mixture of sync/async sockets in IOPubThread all sockets are explicitly sync until/except we are in the coroutines that will await them - consistent behavior of send for child pipe and main process sockets - avoids unsafe assumption that send is greedy on async sockets - avoids potential issues creating async objects in one thread, then using them in another in a different event loop - always creates/uses the right types, regardless of input socket - address some typing lint --- ipykernel/inprocess/ipkernel.py | 4 +- ipykernel/inprocess/socket.py | 3 ++ ipykernel/iostream.py | 74 +++++++++++++++++++-------------- 3 files changed, 48 insertions(+), 33 deletions(-) diff --git a/ipykernel/inprocess/ipkernel.py b/ipykernel/inprocess/ipkernel.py index 114e231d..c6f8c612 100644 --- a/ipykernel/inprocess/ipkernel.py +++ b/ipykernel/inprocess/ipkernel.py @@ -6,6 +6,7 @@ import logging import sys from contextlib import contextmanager +from typing import cast from anyio import TASK_STATUS_IGNORED from anyio.abc import TaskStatus @@ -146,7 +147,8 @@ def callback(msg): assert frontend is not None frontend.iopub_channel.call_handlers(msg) - self.iopub_thread.socket.on_recv = callback + iopub_socket = cast(DummySocket, self.iopub_thread.socket) + iopub_socket.on_recv = callback # ------ Trait initializers ----------------------------------------------- diff --git a/ipykernel/inprocess/socket.py b/ipykernel/inprocess/socket.py index edc77c28..5a2e0008 100644 --- a/ipykernel/inprocess/socket.py +++ b/ipykernel/inprocess/socket.py @@ -63,3 +63,6 @@ async def poll(self, timeout=0): assert timeout == 0 statistics = self.in_receive_stream.statistics() return statistics.current_buffer_used != 0 + + def close(self): + pass diff --git a/ipykernel/iostream.py b/ipykernel/iostream.py index beca44b1..0d29b38c 100644 --- a/ipykernel/iostream.py +++ b/ipykernel/iostream.py @@ -3,6 +3,8 @@ # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + import atexit import contextvars import io @@ -15,7 +17,7 @@ from collections import defaultdict, deque from io import StringIO, TextIOBase from threading import Event, Thread, local -from typing import Any, Callable, Deque, Dict, Optional +from typing import Any, Callable import zmq from anyio import create_task_group, run, sleep, to_thread @@ -25,8 +27,8 @@ # Globals # ----------------------------------------------------------------------------- -MASTER = 0 -CHILD = 1 +_PARENT = 0 +_CHILD = 1 PIPE_BUFFER_SIZE = 1000 @@ -87,15 +89,19 @@ def __init__(self, socket, pipe=False): Whether this process should listen for IOPub messages piped from subprocesses. """ - self.socket = socket + # ensure all of our sockets as sync zmq.Sockets + # don't create async wrappers until we are within the appropriate coroutines + self.socket: zmq.Socket[bytes] | None = zmq.Socket(socket) + self._sync_context: zmq.Context[zmq.Socket[bytes]] = zmq.Context(socket.context) + self.background_socket = BackgroundSocket(self) - self._master_pid = os.getpid() + self._main_pid = os.getpid() self._pipe_flag = pipe if pipe: self._setup_pipe_in() self._local = threading.local() - self._events: Deque[Callable[..., Any]] = deque() - self._event_pipes: Dict[threading.Thread, Any] = {} + self._events: deque[Callable[..., Any]] = deque() + self._event_pipes: dict[threading.Thread, Any] = {} self._event_pipe_gc_lock: threading.Lock = threading.Lock() self._event_pipe_gc_seconds: float = 10 self._setup_event_pipe() @@ -106,7 +112,7 @@ def __init__(self, socket, pipe=False): def _setup_event_pipe(self): """Create the PULL socket listening for events that should fire in this thread.""" - ctx = self.socket.context + ctx = self._sync_context self._pipe_in0 = ctx.socket(zmq.PULL) self._pipe_in0.linger = 0 @@ -141,8 +147,7 @@ def _event_pipe(self): event_pipe = self._local.event_pipe except AttributeError: # new thread, new event pipe - ctx = zmq.Context(self.socket.context) - event_pipe = ctx.socket(zmq.PUSH) + event_pipe = self._sync_context.socket(zmq.PUSH) event_pipe.linger = 0 event_pipe.connect(self._event_interface) self._local.event_pipe = event_pipe @@ -161,9 +166,11 @@ async def _handle_event(self): Whenever *an* event arrives on the event stream, *all* waiting events are processed in order. """ + # create async wrapper within coroutine + pipe_in = zmq.asyncio.Socket.shadow(self._pipe_in0) try: while True: - await self._pipe_in0.recv() + await pipe_in.recv() # freeze event count so new writes don't extend the queue # while we are processing n_events = len(self._events) @@ -177,7 +184,7 @@ async def _handle_event(self): def _setup_pipe_in(self): """setup listening pipe for IOPub from forked subprocesses""" - ctx = self.socket.context + ctx = self._sync_context # use UUID to authenticate pipe messages self._pipe_uuid = os.urandom(16) @@ -199,6 +206,8 @@ def _setup_pipe_in(self): async def _handle_pipe_msgs(self): """handle pipe messages from a subprocess""" + # create async wrapper within coroutine + self._async_pipe_in1 = zmq.asyncio.Socket(self._pipe_in1) try: while True: await self._handle_pipe_msg() @@ -209,8 +218,8 @@ async def _handle_pipe_msgs(self): async def _handle_pipe_msg(self, msg=None): """handle a pipe message from a subprocess""" - msg = msg or await self._pipe_in1.recv_multipart() - if not self._pipe_flag or not self._is_master_process(): + msg = msg or await self._async_pipe_in1.recv_multipart() + if not self._pipe_flag or not self._is_main_process(): return if msg[0] != self._pipe_uuid: print("Bad pipe message: %s", msg, file=sys.__stderr__) @@ -225,14 +234,14 @@ def _setup_pipe_out(self): pipe_out.connect("tcp://127.0.0.1:%i" % self._pipe_port) return ctx, pipe_out - def _is_master_process(self): - return os.getpid() == self._master_pid + def _is_main_process(self): + return os.getpid() == self._main_pid def _check_mp_mode(self): """check for forks, and switch to zmq pipeline if necessary""" - if not self._pipe_flag or self._is_master_process(): - return MASTER - return CHILD + if not self._pipe_flag or self._is_main_process(): + return _PARENT + return _CHILD def start(self): """Start the IOPub thread""" @@ -265,7 +274,8 @@ def close(self): self._pipe_in0.close() if self._pipe_flag: self._pipe_in1.close() - self.socket.close() + if self.socket is not None: + self.socket.close() self.socket = None @property @@ -301,12 +311,12 @@ def _really_send(self, msg, *args, **kwargs): return mp_mode = self._check_mp_mode() - - if mp_mode != CHILD: - # we are master, do a regular send + if mp_mode != _CHILD: + # we are the main parent process, do a regular send + assert self.socket is not None self.socket.send_multipart(msg, *args, **kwargs) else: - # we are a child, pipe to master + # we are a child, pipe to parent process # new context/socket for every pipe-out # since forks don't teardown politely, use ctx.term to ensure send has completed ctx, pipe_out = self._setup_pipe_out() @@ -379,7 +389,7 @@ class OutStream(TextIOBase): flush_interval = 0.2 topic = None encoding = "UTF-8" - _exc: Optional[Any] = None + _exc: Any = None def fileno(self): """ @@ -470,14 +480,14 @@ def __init__( self.pub_thread = pub_thread self.name = name self.topic = b"stream." + name.encode() - self._parent_header: contextvars.ContextVar[Dict[str, Any]] = contextvars.ContextVar( + self._parent_header: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar( "parent_header" ) self._parent_header.set({}) self._thread_to_parent = {} self._thread_to_parent_header = {} self._parent_header_global = {} - self._master_pid = os.getpid() + self._main_pid = os.getpid() self._flush_pending = False self._subprocess_flush_pending = False self._buffer_lock = threading.RLock() @@ -569,8 +579,8 @@ def _setup_stream_redirects(self, name): self.watch_fd_thread.daemon = True self.watch_fd_thread.start() - def _is_master_process(self): - return os.getpid() == self._master_pid + def _is_main_process(self): + return os.getpid() == self._main_pid def set_parent(self, parent): """Set the parent header.""" @@ -674,7 +684,7 @@ def _flush(self): ident=self.topic, ) - def write(self, string: str) -> Optional[int]: # type:ignore[override] + def write(self, string: str) -> int: """Write to current stream after encoding if necessary Returns @@ -700,7 +710,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override] msg = "I/O operation on closed file" raise ValueError(msg) - is_child = not self._is_master_process() + is_child = not self._is_main_process() # only touch the buffer in the IO thread to avoid races with self._buffer_lock: self._buffers[frozenset(parent.items())].write(string) @@ -708,7 +718,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override] # mp.Pool cannot be trusted to flush promptly (or ever), # and this helps. if self._subprocess_flush_pending: - return None + return 0 self._subprocess_flush_pending = True # We can not rely on self._io_loop.call_later from a subprocess self.pub_thread.schedule(self._flush) From 3b1188a04c5c4dc9ee2655d842fdb45dc0d3df69 Mon Sep 17 00:00:00 2001 From: Min RK Date: Tue, 22 Oct 2024 11:05:40 +0200 Subject: [PATCH 2/4] consistent shadow socket creation --- ipykernel/iostream.py | 2 +- tests/test_io.py | 4 ++-- tests/test_kernel.py | 24 ++++++++++++------------ 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/ipykernel/iostream.py b/ipykernel/iostream.py index 0d29b38c..522478c7 100644 --- a/ipykernel/iostream.py +++ b/ipykernel/iostream.py @@ -167,7 +167,7 @@ async def _handle_event(self): *all* waiting events are processed in order. """ # create async wrapper within coroutine - pipe_in = zmq.asyncio.Socket.shadow(self._pipe_in0) + pipe_in = zmq.asyncio.Socket(self._pipe_in0) try: while True: await pipe_in.recv() diff --git a/tests/test_io.py b/tests/test_io.py index e49bc276..e3ff2815 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -15,7 +15,7 @@ import zmq.asyncio from jupyter_client.session import Session -from ipykernel.iostream import MASTER, BackgroundSocket, IOPubThread, OutStream +from ipykernel.iostream import _PARENT, BackgroundSocket, IOPubThread, OutStream @pytest.fixture() @@ -73,7 +73,7 @@ async def test_io_thread(anyio_backend, iopub_thread): ctx1, pipe = thread._setup_pipe_out() pipe.close() thread._pipe_in1.close() - thread._check_mp_mode = lambda: MASTER + thread._check_mp_mode = lambda: _PARENT thread._really_send([b"hi"]) ctx1.destroy() thread.stop() diff --git a/tests/test_kernel.py b/tests/test_kernel.py index 88f02ae9..727a7a9c 100644 --- a/tests/test_kernel.py +++ b/tests/test_kernel.py @@ -32,10 +32,10 @@ ) -def _check_master(kc, expected=True, stream="stdout"): +def _check_main(kc, expected=True, stream="stdout"): execute(kc=kc, code="import sys") flush_channels(kc) - msg_id, content = execute(kc=kc, code="print(sys.%s._is_master_process())" % stream) + msg_id, content = execute(kc=kc, code="print(sys.%s._is_main_process())" % stream) stdout, stderr = assemble_output(kc.get_iopub_msg) assert stdout.strip() == repr(expected) @@ -56,7 +56,7 @@ def test_simple_print(): stdout, stderr = assemble_output(kc.get_iopub_msg) assert stdout == "hi\n" assert stderr == "" - _check_master(kc, expected=True) + _check_main(kc, expected=True) def test_print_to_correct_cell_from_thread(): @@ -168,7 +168,7 @@ def test_capture_fd(): stdout, stderr = assemble_output(iopub) assert stdout == "capsys\n" assert stderr == "" - _check_master(kc, expected=True) + _check_main(kc, expected=True) @pytest.mark.skip(reason="Currently don't capture during test as pytest does its own capturing") @@ -182,7 +182,7 @@ def test_subprocess_peek_at_stream_fileno(): stdout, stderr = assemble_output(iopub) assert stdout == "CAP1\nCAP2\n" assert stderr == "" - _check_master(kc, expected=True) + _check_main(kc, expected=True) def test_sys_path(): @@ -218,7 +218,7 @@ def test_sys_path_profile_dir(): def test_subprocess_print(): """printing from forked mp.Process""" with new_kernel() as kc: - _check_master(kc, expected=True) + _check_main(kc, expected=True) flush_channels(kc) np = 5 code = "\n".join( @@ -238,8 +238,8 @@ def test_subprocess_print(): for n in range(np): assert stdout.count(str(n)) == 1, stdout assert stderr == "" - _check_master(kc, expected=True) - _check_master(kc, expected=True, stream="stderr") + _check_main(kc, expected=True) + _check_main(kc, expected=True, stream="stderr") @flaky(max_runs=3) @@ -261,8 +261,8 @@ def test_subprocess_noprint(): assert stdout == "" assert stderr == "" - _check_master(kc, expected=True) - _check_master(kc, expected=True, stream="stderr") + _check_main(kc, expected=True) + _check_main(kc, expected=True, stream="stderr") @flaky(max_runs=3) @@ -287,8 +287,8 @@ def test_subprocess_error(): assert stdout == "" assert "ValueError" in stderr - _check_master(kc, expected=True) - _check_master(kc, expected=True, stream="stderr") + _check_main(kc, expected=True) + _check_main(kc, expected=True, stream="stderr") # raw_input tests From 64ff5d4dddf8ea8345d51b31927b3a6c01703ea8 Mon Sep 17 00:00:00 2001 From: Min RK Date: Tue, 22 Oct 2024 11:32:50 +0200 Subject: [PATCH 3/4] avoid shadowing context instead, keep same context but use `socket_class` kwarg to specify socket classes shadow context prevents cleanup of untracked sockets via ctx.destroy because it disconnects socket bookkeeping --- ipykernel/iostream.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/ipykernel/iostream.py b/ipykernel/iostream.py index 522478c7..bcfc9086 100644 --- a/ipykernel/iostream.py +++ b/ipykernel/iostream.py @@ -92,7 +92,10 @@ def __init__(self, socket, pipe=False): # ensure all of our sockets as sync zmq.Sockets # don't create async wrappers until we are within the appropriate coroutines self.socket: zmq.Socket[bytes] | None = zmq.Socket(socket) - self._sync_context: zmq.Context[zmq.Socket[bytes]] = zmq.Context(socket.context) + if self.socket.context is None: + # bug in pyzmq, shadow socket doesn't always inherit context attribute + self.socket.context = socket.context + self._context = socket.context self.background_socket = BackgroundSocket(self) self._main_pid = os.getpid() @@ -112,8 +115,7 @@ def __init__(self, socket, pipe=False): def _setup_event_pipe(self): """Create the PULL socket listening for events that should fire in this thread.""" - ctx = self._sync_context - self._pipe_in0 = ctx.socket(zmq.PULL) + self._pipe_in0 = self._context.socket(zmq.PULL, socket_class=zmq.Socket) self._pipe_in0.linger = 0 _uuid = b2a_hex(os.urandom(16)).decode("ascii") @@ -147,7 +149,8 @@ def _event_pipe(self): event_pipe = self._local.event_pipe except AttributeError: # new thread, new event pipe - event_pipe = self._sync_context.socket(zmq.PUSH) + # create sync base socket + event_pipe = self._context.socket(zmq.PUSH, socket_class=zmq.Socket) event_pipe.linger = 0 event_pipe.connect(self._event_interface) self._local.event_pipe = event_pipe @@ -184,12 +187,12 @@ async def _handle_event(self): def _setup_pipe_in(self): """setup listening pipe for IOPub from forked subprocesses""" - ctx = self._sync_context + ctx = self._context # use UUID to authenticate pipe messages self._pipe_uuid = os.urandom(16) - self._pipe_in1 = ctx.socket(zmq.PULL) + self._pipe_in1 = ctx.socket(zmq.PULL, socket_class=zmq.Socket) self._pipe_in1.linger = 0 try: From 53b2911dbc103d120dbde08fad3e3f9b1f1643e6 Mon Sep 17 00:00:00 2001 From: Min RK Date: Tue, 22 Oct 2024 12:53:18 +0200 Subject: [PATCH 4/4] unhelpful linter --- ipykernel/iostream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ipykernel/iostream.py b/ipykernel/iostream.py index bcfc9086..d8171017 100644 --- a/ipykernel/iostream.py +++ b/ipykernel/iostream.py @@ -94,7 +94,7 @@ def __init__(self, socket, pipe=False): self.socket: zmq.Socket[bytes] | None = zmq.Socket(socket) if self.socket.context is None: # bug in pyzmq, shadow socket doesn't always inherit context attribute - self.socket.context = socket.context + self.socket.context = socket.context # type:ignore[unreachable] self._context = socket.context self.background_socket = BackgroundSocket(self)