diff --git a/distributed/core.py b/distributed/core.py index a4bb031c12..69ea049677 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -6,6 +6,7 @@ import logging import math import os +import pickle import sys import tempfile import threading @@ -1706,6 +1707,8 @@ def clean_exception( if isinstance(exception, (bytes, bytearray)): try: exception = protocol.pickle.loads(exception) + except pickle.UnpicklingError as e: + exception = e except Exception: exception = Exception(exception) elif isinstance(exception, str): diff --git a/distributed/protocol/pickle.py b/distributed/protocol/pickle.py index cb724af012..1dd46b646f 100644 --- a/distributed/protocol/pickle.py +++ b/distributed/protocol/pickle.py @@ -75,9 +75,9 @@ def dumps(x, *, buffer_callback=None, protocol=HIGHEST_PROTOCOL): try: buffers.clear() result = cloudpickle.dumps(x, **dump_kwargs) - except Exception: + except Exception as e: logger.exception("Failed to serialize %s.", x) - raise + raise pickle.PicklingError("Failed to serialize", x, buffers) from e if buffer_callback is not None: for b in buffers: buffer_callback(b) @@ -90,6 +90,8 @@ def loads(x, *, buffers=()): return pickle.loads(x, buffers=buffers) else: return pickle.loads(x) - except Exception: - logger.info("Failed to deserialize %s", x[:10000], exc_info=True) + except EOFError: raise + except Exception as e: + logger.info("Failed to deserialize %s", x[:10000], exc_info=True) + raise pickle.UnpicklingError("Failed to deserialize", x, buffers) from e diff --git a/distributed/protocol/tests/test_pickle.py b/distributed/protocol/tests/test_pickle.py index 3d9ef38de7..b5fd563206 100644 --- a/distributed/protocol/tests/test_pickle.py +++ b/distributed/protocol/tests/test_pickle.py @@ -15,7 +15,7 @@ from distributed.protocol import deserialize, serialize from distributed.protocol.pickle import HIGHEST_PROTOCOL, dumps, loads from distributed.protocol.serialize import dask_deserialize, dask_serialize -from distributed.utils_test import popen, save_sys_modules +from distributed.utils_test import popen, raises_with_cause, save_sys_modules class MemoryviewHolder: @@ -231,7 +231,7 @@ def _deserialize_nopickle(header, frames): def test_allow_pickle_if_registered_in_dask_serialize(): - with pytest.raises(TypeError, match="nope"): + with raises_with_cause(pickle.PicklingError, "serialize", TypeError, "nope"): dumps(NoPickle()) dask_serialize.register(NoPickle)(_serialize_nopickle) @@ -251,9 +251,9 @@ def __init__(self) -> None: def test_nopickle_nested(): nested_obj = [NoPickle()] - with pytest.raises(TypeError, match="nope"): + with raises_with_cause(pickle.PicklingError, "serialize", TypeError, "nope"): dumps(nested_obj) - with pytest.raises(TypeError, match="nope"): + with raises_with_cause(pickle.PicklingError, "serialize", TypeError, "nope"): dumps(NestedNoPickle()) dask_serialize.register(NoPickle)(_serialize_nopickle) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 0016294e70..fe86dcf927 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5100,7 +5100,14 @@ def __setstate__(self, state): future = c.submit(identity, Foo()) await wait(future) assert future.status == "error" - with raises_with_cause(RuntimeError, "deserialization", MyException, "hello"): + with raises_with_cause( + RuntimeError, + "deserialization", + pickle.UnpicklingError, + "deserialize", + MyException, + "hello", + ): await future futures = c.map(inc, range(10)) @@ -5125,7 +5132,14 @@ def __call__(self, *args): future = c.submit(Foo(), 1) await wait(future) assert future.status == "error" - with raises_with_cause(RuntimeError, "deserialization", MyException, "hello"): + with raises_with_cause( + RuntimeError, + "deserialization", + pickle.UnpicklingError, + "deserialize", + MyException, + "hello", + ): await future futures = c.map(inc, range(10)) diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 6f4836ae08..fdcad7b505 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -4,6 +4,7 @@ import contextlib import logging import os +import pickle import random import socket import sys @@ -167,6 +168,16 @@ async def test_server_raises_on_blocked_handlers(): await comm.close() +@gen_test() +async def test_unreadable_exception_in_clean_exception(): + pickled_ex = b"\x80\x04\x954\x00\x00\x00\x00\x00\x00\x00\x8c\x08__main__\x94\x8c\x14SomeUnknownException\x94\x93\x94\x8c\x08some arg\x94\x85\x94R\x94." + ex_type, ex, tb = clean_exception(pickled_ex) + assert ex_type == pickle.UnpicklingError + assert isinstance(ex, pickle.UnpicklingError) + assert ex.args[1] == pickled_ex + assert tb is None + + class MyServer(Server): default_port = 8756 diff --git a/distributed/tests/test_spill.py b/distributed/tests/test_spill.py index 604f7b98c7..7372a6b43e 100644 --- a/distributed/tests/test_spill.py +++ b/distributed/tests/test_spill.py @@ -222,7 +222,7 @@ def test_spillbuffer_fail_to_serialize(tmp_path): with pytest.raises(TypeError, match="Failed to pickle 'a'") as e: with captured_logger("distributed.spill") as logs_bad_key: buf["a"] = a - assert isinstance(e.value.__cause__.__cause__, MyError) + assert isinstance(e.value.__cause__.__cause__.__cause__, MyError) # spill.py must remain silent because we're already logging in worker.py assert not logs_bad_key.getvalue() diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 3976857c9e..8f62a2bca3 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -70,6 +70,8 @@ async def test_work_stealing(c, s, a, b): await wait(futures) assert len(a.data) > 10 assert len(b.data) > 10 + del futures + print("Hi") @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 0a2fed3219..f2949b6518 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -6,6 +6,7 @@ import itertools import logging import os +import pickle import random import sys import tempfile @@ -38,6 +39,7 @@ get_client, get_worker, profile, + protocol, wait, ) from distributed.comm.registry import backends @@ -46,7 +48,6 @@ from distributed.core import CommClosedError, Status, rpc from distributed.diagnostics.plugin import ForwardOutput from distributed.metrics import time -from distributed.protocol import pickle from distributed.scheduler import KilledWorker, Scheduler from distributed.utils import get_mp_context, wait_for from distributed.utils_test import ( @@ -509,13 +510,15 @@ async def test_plugin_internal_exception(): with raises_with_cause( RuntimeError, "Worker failed to start", + pickle.UnpicklingError, + "deserialize", UnicodeDecodeError, - match_cause="codec can't decode", + "codec can't decode", ): async with Worker( s.address, plugins={ - b"corrupting pickle" + pickle.dumps(lambda: None), + b"corrupting pickle" + protocol.pickle.dumps(lambda: None), }, ) as w: pass diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py index 4a222861d8..1f3b4be72d 100644 --- a/distributed/tests/test_worker_memory.py +++ b/distributed/tests/test_worker_memory.py @@ -176,7 +176,7 @@ async def test_fail_to_pickle_execute_1(c, s, a, b): with pytest.raises(TypeError, match="Failed to pickle 'x'") as e: await x - assert isinstance(e.value.__cause__.__cause__, CustomError) + assert isinstance(e.value.__cause__.__cause__.__cause__, CustomError) await assert_basic_futures(c)