From 53f0039e27813c25e23881fa47ff264faf017334 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 11 Oct 2024 14:26:30 +0200 Subject: [PATCH] Improve clean_exception --- distributed/core.py | 3 +++ distributed/protocol/pickle.py | 4 ++-- distributed/tests/test_core.py | 11 +++++++++++ distributed/tests/test_steal.py | 2 ++ 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index a4bb031c12..69ea049677 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -6,6 +6,7 @@ import logging import math import os +import pickle import sys import tempfile import threading @@ -1706,6 +1707,8 @@ def clean_exception( if isinstance(exception, (bytes, bytearray)): try: exception = protocol.pickle.loads(exception) + except pickle.UnpicklingError as e: + exception = e except Exception: exception = Exception(exception) elif isinstance(exception, str): diff --git a/distributed/protocol/pickle.py b/distributed/protocol/pickle.py index 0b0988df8e..1dd46b646f 100644 --- a/distributed/protocol/pickle.py +++ b/distributed/protocol/pickle.py @@ -77,7 +77,7 @@ def dumps(x, *, buffer_callback=None, protocol=HIGHEST_PROTOCOL): result = cloudpickle.dumps(x, **dump_kwargs) except Exception as e: logger.exception("Failed to serialize %s.", x) - raise pickle.PicklingError("Failed to serialize") from e + raise pickle.PicklingError("Failed to serialize", x, buffers) from e if buffer_callback is not None: for b in buffers: buffer_callback(b) @@ -94,4 +94,4 @@ def loads(x, *, buffers=()): raise except Exception as e: logger.info("Failed to deserialize %s", x[:10000], exc_info=True) - raise pickle.UnpicklingError("Failed to deserialize") from e + raise pickle.UnpicklingError("Failed to deserialize", x, buffers) from e diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 6f4836ae08..fdcad7b505 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -4,6 +4,7 @@ import contextlib import logging import os +import pickle import random import socket import sys @@ -167,6 +168,16 @@ async def test_server_raises_on_blocked_handlers(): await comm.close() +@gen_test() +async def test_unreadable_exception_in_clean_exception(): + pickled_ex = b"\x80\x04\x954\x00\x00\x00\x00\x00\x00\x00\x8c\x08__main__\x94\x8c\x14SomeUnknownException\x94\x93\x94\x8c\x08some arg\x94\x85\x94R\x94." + ex_type, ex, tb = clean_exception(pickled_ex) + assert ex_type == pickle.UnpicklingError + assert isinstance(ex, pickle.UnpicklingError) + assert ex.args[1] == pickled_ex + assert tb is None + + class MyServer(Server): default_port = 8756 diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 3976857c9e..8f62a2bca3 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -70,6 +70,8 @@ async def test_work_stealing(c, s, a, b): await wait(futures) assert len(a.data) > 10 assert len(b.data) > 10 + del futures + print("Hi") @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)