diff --git a/ipykernel/kernelapp.py b/ipykernel/kernelapp.py index 2f462af4..394f52a4 100644 --- a/ipykernel/kernelapp.py +++ b/ipykernel/kernelapp.py @@ -220,7 +220,7 @@ def init_poller(self): # PID 1 (init) is special and will never go away, # only be reassigned. # Parent polling doesn't work if ppid == 1 to start with. - self.poller = ParentPollerUnix() + self.poller = ParentPollerUnix(parent_pid=self.parent_handle) def _try_bind_socket(self, s, port): iface = f"{self.transport}://{self.ip}" diff --git a/ipykernel/parentpoller.py b/ipykernel/parentpoller.py index a6d9c753..895a785c 100644 --- a/ipykernel/parentpoller.py +++ b/ipykernel/parentpoller.py @@ -22,9 +22,17 @@ class ParentPollerUnix(Thread): when the parent process no longer exists. """ - def __init__(self): - """Initialize the poller.""" + def __init__(self, parent_pid=0): + """Initialize the poller. + + Parameters + ---------- + parent_handle : int, optional + If provided, the program will terminate immediately when + process parent is no longer this original parent. + """ super().__init__() + self.parent_pid = parent_pid self.daemon = True def run(self): @@ -32,9 +40,23 @@ def run(self): # We cannot use os.waitpid because it works only for child processes. from errno import EINTR + # before start, check if the passed-in parent pid is valid + original_ppid = os.getppid() + if original_ppid != self.parent_pid: + self.parent_pid = 0 + + get_logger().debug( + "%s: poll for parent change with original parent pid=%d", + type(self).__name__, + self.parent_pid, + ) + while True: try: - if os.getppid() == 1: + ppid = os.getppid() + parent_is_init = not self.parent_pid and ppid == 1 + parent_has_changed = self.parent_pid and ppid != self.parent_pid + if parent_is_init or parent_has_changed: get_logger().warning("Parent appears to have exited, shutting down.") os._exit(1) time.sleep(1.0) diff --git a/tests/test_parentpoller.py b/tests/test_parentpoller.py index 97cd8044..716c9e8f 100644 --- a/tests/test_parentpoller.py +++ b/tests/test_parentpoller.py @@ -9,7 +9,7 @@ @pytest.mark.skipif(os.name == "nt", reason="only works on posix") -def test_parent_poller_unix(): +def test_parent_poller_unix_to_pid1(): poller = ParentPollerUnix() with mock.patch("os.getppid", lambda: 1): # noqa: PT008 @@ -27,6 +27,22 @@ def mock_getppid(): poller.run() +@pytest.mark.skipif(os.name == "nt", reason="only works on posix") +def test_parent_poller_unix_reparent_not_pid1(): + parent_pid = 221 + parent_pids = iter([parent_pid, parent_pid - 1]) + + poller = ParentPollerUnix(parent_pid=parent_pid) + + with mock.patch("os.getppid", lambda: next(parent_pids)): # noqa: PT008 + + def exit_mock(*args): + sys.exit(1) + + with mock.patch("os._exit", exit_mock), pytest.raises(SystemExit): + poller.run() + + @pytest.mark.skipif(os.name != "nt", reason="only works on windows") def test_parent_poller_windows(): poller = ParentPollerWindows(interrupt_handle=1)