From 84955484ec1636ee4c7611471d20df2016b5cb57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Krassowski?= <5832902+krassowski@users.noreply.github.com> Date: Tue, 16 Jan 2024 14:41:20 +0000 Subject: [PATCH] Make outputs go to correct cell when generated in threads/asyncio (#1186) Co-authored-by: Steven Silvester --- ipykernel/iostream.py | 105 ++++++++++++++++++++++++++-------------- ipykernel/ipkernel.py | 94 +++++++++++++++++++++++++++++++++++ ipykernel/kernelbase.py | 8 +++ tests/test_kernel.py | 100 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 271 insertions(+), 36 deletions(-) diff --git a/ipykernel/iostream.py b/ipykernel/iostream.py index 0bbdbe27..257b5c80 100644 --- a/ipykernel/iostream.py +++ b/ipykernel/iostream.py @@ -5,6 +5,7 @@ import asyncio import atexit +import contextvars import io import os import sys @@ -12,7 +13,7 @@ import traceback import warnings from binascii import b2a_hex -from collections import deque +from collections import defaultdict, deque from io import StringIO, TextIOBase from threading import local from typing import Any, Callable, Deque, Dict, Optional @@ -412,7 +413,7 @@ def __init__( name : str {'stderr', 'stdout'} the name of the standard stream to replace pipe : object - the pip object + the pipe object echo : bool whether to echo output watchfd : bool (default, True) @@ -446,13 +447,19 @@ def __init__( self.pub_thread = pub_thread self.name = name self.topic = b"stream." 
+ name.encode() - self.parent_header = {} + self._parent_header: contextvars.ContextVar[Dict[str, Any]] = contextvars.ContextVar( + "parent_header" + ) + self._parent_header.set({}) + self._thread_to_parent = {} + self._thread_to_parent_header = {} + self._parent_header_global = {} self._master_pid = os.getpid() self._flush_pending = False self._subprocess_flush_pending = False self._io_loop = pub_thread.io_loop self._buffer_lock = threading.RLock() - self._buffer = StringIO() + self._buffers = defaultdict(StringIO) self.echo = None self._isatty = bool(isatty) self._should_watch = False @@ -495,6 +502,30 @@ def __init__( msg = "echo argument must be a file-like object" raise ValueError(msg) + @property + def parent_header(self): + try: + # asyncio-specific + return self._parent_header.get() + except LookupError: + try: + # thread-specific + identity = threading.current_thread().ident + # retrieve the outermost (oldest ancestor, + # discounting the kernel thread) thread identity + while identity in self._thread_to_parent: + identity = self._thread_to_parent[identity] + # use the header of the oldest ancestor + return self._thread_to_parent_header[identity] + except KeyError: + # global (fallback) + return self._parent_header_global + + @parent_header.setter + def parent_header(self, value): + self._parent_header_global = value + return self._parent_header.set(value) + def isatty(self): """Return a bool indicating whether this is an 'interactive' stream. @@ -598,28 +629,28 @@ def _flush(self): if self.echo is not sys.__stderr__: print(f"Flush failed: {e}", file=sys.__stderr__) - data = self._flush_buffer() - if data: - # FIXME: this disables Session's fork-safe check, - # since pub_thread is itself fork-safe. - # There should be a better way to do this. - self.session.pid = os.getpid() - content = {"name": self.name, "text": data} - msg = self.session.msg("stream", content, parent=self.parent_header) - - # Each transform either returns a new - # message or None. 
If None is returned, - # the message has been 'used' and we return. - for hook in self._hooks: - msg = hook(msg) - if msg is None: - return - - self.session.send( - self.pub_thread, - msg, - ident=self.topic, - ) + for parent, data in self._flush_buffers(): + if data: + # FIXME: this disables Session's fork-safe check, + # since pub_thread is itself fork-safe. + # There should be a better way to do this. + self.session.pid = os.getpid() + content = {"name": self.name, "text": data} + msg = self.session.msg("stream", content, parent=parent) + + # Each transform either returns a new + # message or None. If None is returned, + # the message has been 'used' and we return. + for hook in self._hooks: + msg = hook(msg) + if msg is None: + return + + self.session.send( + self.pub_thread, + msg, + ident=self.topic, + ) def write(self, string: str) -> Optional[int]: # type:ignore[override] """Write to current stream after encoding if necessary @@ -630,6 +661,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override] number of items from input parameter written to stream. """ + parent = self.parent_header if not isinstance(string, str): msg = f"write() argument must be str, not {type(string)}" # type:ignore[unreachable] @@ -649,7 +681,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override] is_child = not self._is_master_process() # only touch the buffer in the IO thread to avoid races with self._buffer_lock: - self._buffer.write(string) + self._buffers[frozenset(parent.items())].write(string) if is_child: # mp.Pool cannot be trusted to flush promptly (or ever), # and this helps. 
@@ -675,19 +707,20 @@ def writable(self): """Test whether the stream is writable.""" return True - def _flush_buffer(self): + def _flush_buffers(self): """clear the current buffer and return the current buffer data.""" - buf = self._rotate_buffer() - data = buf.getvalue() - buf.close() - return data + buffers = self._rotate_buffers() + for frozen_parent, buffer in buffers.items(): + data = buffer.getvalue() + buffer.close() + yield dict(frozen_parent), data - def _rotate_buffer(self): + def _rotate_buffers(self): """Returns the current buffer and replaces it with an empty buffer.""" with self._buffer_lock: - old_buffer = self._buffer - self._buffer = StringIO() - return old_buffer + old_buffers = self._buffers + self._buffers = defaultdict(StringIO) + return old_buffers @property def _hooks(self): diff --git a/ipykernel/ipkernel.py b/ipykernel/ipkernel.py index 2d0ec6ff..40d57945 100644 --- a/ipykernel/ipkernel.py +++ b/ipykernel/ipkernel.py @@ -2,6 +2,7 @@ import asyncio import builtins +import gc import getpass import os import signal @@ -14,6 +15,7 @@ import comm from IPython.core import release from IPython.utils.tokenutil import line_at_cursor, token_at_cursor +from jupyter_client.session import extract_header from traitlets import Any, Bool, HasTraits, Instance, List, Type, observe, observe_compat from zmq.eventloop.zmqstream import ZMQStream @@ -22,6 +24,7 @@ from .compiler import XCachingCompiler from .debugger import Debugger, _is_debugpy_available from .eventloops import _use_appnope +from .iostream import OutStream from .kernelbase import Kernel as KernelBase from .kernelbase import _accepts_parameters from .zmqshell import ZMQInteractiveShell @@ -151,6 +154,14 @@ def __init__(self, **kwargs): appnope.nope() + self._new_threads_parent_header = {} + self._initialize_thread_hooks() + + if hasattr(gc, "callbacks"): + # while `gc.callbacks` exists since Python 3.3, pypy does not + # implement it even as of 3.9. 
+ gc.callbacks.append(self._clean_thread_parent_frames) + help_links = List( [ { @@ -341,6 +352,12 @@ def set_sigint_result(): # restore the previous sigint handler signal.signal(signal.SIGINT, save_sigint) + async def execute_request(self, stream, ident, parent): + """Override for cell output - cell reconciliation.""" + parent_header = extract_header(parent) + self._associate_new_top_level_threads_with(parent_header) + await super().execute_request(stream, ident, parent) + async def do_execute( self, code, @@ -706,6 +723,83 @@ def do_clear(self): self.shell.reset(False) return dict(status="ok") + def _associate_new_top_level_threads_with(self, parent_header): + """Store the parent header to associate it with new top-level threads""" + self._new_threads_parent_header = parent_header + + def _initialize_thread_hooks(self): + """Store thread hierarchy and thread-parent_header associations.""" + stdout = self._stdout + stderr = self._stderr + kernel_thread_ident = threading.get_ident() + kernel = self + _threading_Thread_run = threading.Thread.run + _threading_Thread__init__ = threading.Thread.__init__ + + def run_closure(self: threading.Thread): + """Wrap `threading.Thread.run` to intercept thread identity. + + This is needed because there is no "start" hook yet, but there + might be one in the future: https://bugs.python.org/issue14073 + + This is a no-op if the `self._stdout` and `self._stderr` are not + sub-classes of `OutStream`. 
+ """ + + try: + parent = self._ipykernel_parent_thread_ident # type:ignore[attr-defined] + except AttributeError: + return + for stream in [stdout, stderr]: + if isinstance(stream, OutStream): + if parent == kernel_thread_ident: + stream._thread_to_parent_header[ + self.ident + ] = kernel._new_threads_parent_header + else: + stream._thread_to_parent[self.ident] = parent + _threading_Thread_run(self) + + def init_closure(self: threading.Thread, *args, **kwargs): + _threading_Thread__init__(self, *args, **kwargs) + self._ipykernel_parent_thread_ident = threading.get_ident() # type:ignore[attr-defined] + + threading.Thread.__init__ = init_closure # type:ignore[method-assign] + threading.Thread.run = run_closure # type:ignore[method-assign] + + def _clean_thread_parent_frames( + self, phase: t.Literal["start", "stop"], info: t.Dict[str, t.Any] + ): + """Clean parent frames of threads which are no longer running. + This is meant to be invoked by garbage collector callback hook. + + The implementation enumerates the threads because there is no "exit" hook yet, + but there might be one in the future: https://bugs.python.org/issue14073 + + This is a no-op if the `self._stdout` and `self._stderr` are not + sub-classes of `OutStream`. 
+ """ + # Only run before the garbage collector starts + if phase != "start": + return + active_threads = {thread.ident for thread in threading.enumerate()} + for stream in [self._stdout, self._stderr]: + if isinstance(stream, OutStream): + thread_to_parent_header = stream._thread_to_parent_header + for identity in list(thread_to_parent_header.keys()): + if identity not in active_threads: + try: + del thread_to_parent_header[identity] + except KeyError: + pass + thread_to_parent = stream._thread_to_parent + for identity in list(thread_to_parent.keys()): + if identity not in active_threads: + try: + del thread_to_parent[identity] + except KeyError: + pass + # This exists only for backwards compatibility - use IPythonKernel instead diff --git a/ipykernel/kernelbase.py b/ipykernel/kernelbase.py index f9eb2b94..79bca7b4 100644 --- a/ipykernel/kernelbase.py +++ b/ipykernel/kernelbase.py @@ -61,6 +61,7 @@ from ipykernel.jsonutil import json_clean from ._version import kernel_protocol_version +from .iostream import OutStream def _accepts_parameters(meth, param_names): @@ -272,6 +273,13 @@ def _parent_header(self): def __init__(self, **kwargs): """Initialize the kernel.""" super().__init__(**kwargs) + + # Kernel application may swap stdout and stderr to OutStream, + # which is the case in `IPKernelApp.init_io`, hence `sys.stdout` + # can already be different from TextIO at initialization time. 
+ self._stdout: OutStream | t.TextIO = sys.stdout + self._stderr: OutStream | t.TextIO = sys.stderr + # Build dict of handlers for message types self.shell_handlers = {} for msg_type in self.msg_types: diff --git a/tests/test_kernel.py b/tests/test_kernel.py index 07411bd1..31338896 100644 --- a/tests/test_kernel.py +++ b/tests/test_kernel.py @@ -58,6 +58,106 @@ def test_simple_print(): _check_master(kc, expected=True) +def test_print_to_correct_cell_from_thread(): + """should print to the cell that spawned the thread, not a subsequently run cell""" + iterations = 5 + interval = 0.25 + code = f"""\ + from threading import Thread + from time import sleep + + def thread_target(): + for i in range({iterations}): + print(i, end='', flush=True) + sleep({interval}) + + Thread(target=thread_target).start() + """ + with kernel() as kc: + thread_msg_id = kc.execute(code) + _ = kc.execute("pass") + + received = 0 + while received < iterations: + msg = kc.get_iopub_msg(timeout=interval * 2) + if msg["msg_type"] != "stream": + continue + content = msg["content"] + assert content["name"] == "stdout" + assert content["text"] == str(received) + # this is crucial as the parent header decides to which cell the output goes + assert msg["parent_header"]["msg_id"] == thread_msg_id + received += 1 + + +def test_print_to_correct_cell_from_child_thread(): + """should print to the cell that spawned the thread, not a subsequently run cell""" + iterations = 5 + interval = 0.25 + code = f"""\ + from threading import Thread + from time import sleep + + def child_target(): + for i in range({iterations}): + print(i, end='', flush=True) + sleep({interval}) + + def parent_target(): + sleep({interval}) + Thread(target=child_target).start() + + Thread(target=parent_target).start() + """ + with kernel() as kc: + thread_msg_id = kc.execute(code) + _ = kc.execute("pass") + + received = 0 + while received < iterations: + msg = kc.get_iopub_msg(timeout=interval * 2) + if msg["msg_type"] != "stream": + 
continue + content = msg["content"] + assert content["name"] == "stdout" + assert content["text"] == str(received) + # this is crucial as the parent header decides to which cell the output goes + assert msg["parent_header"]["msg_id"] == thread_msg_id + received += 1 + + +def test_print_to_correct_cell_from_asyncio(): + """should print to the cell that scheduled the task, not a subsequently run cell""" + iterations = 5 + interval = 0.25 + code = f"""\ + import asyncio + + async def async_task(): + for i in range({iterations}): + print(i, end='', flush=True) + await asyncio.sleep({interval}) + + loop = asyncio.get_event_loop() + loop.create_task(async_task()); + """ + with kernel() as kc: + thread_msg_id = kc.execute(code) + _ = kc.execute("pass") + + received = 0 + while received < iterations: + msg = kc.get_iopub_msg(timeout=interval * 2) + if msg["msg_type"] != "stream": + continue + content = msg["content"] + assert content["name"] == "stdout" + assert content["text"] == str(received) + # this is crucial as the parent header decides to which cell the output goes + assert msg["parent_header"]["msg_id"] == thread_msg_id + received += 1 + + @pytest.mark.skip(reason="Currently don't capture during test as pytest does its own capturing") def test_capture_fd(): """simple print statement in kernel"""