From 18685db6d8efe4821373d4e5ea7234c5883ff466 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 10 Oct 2024 19:42:21 +0200 Subject: [PATCH 1/3] Raise explicit pickle errors --- distributed/protocol/pickle.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/distributed/protocol/pickle.py b/distributed/protocol/pickle.py index cb724af012..e0d1143de4 100644 --- a/distributed/protocol/pickle.py +++ b/distributed/protocol/pickle.py @@ -75,9 +75,9 @@ def dumps(x, *, buffer_callback=None, protocol=HIGHEST_PROTOCOL): try: buffers.clear() result = cloudpickle.dumps(x, **dump_kwargs) - except Exception: + except Exception as e: logger.exception("Failed to serialize %s.", x) - raise + raise pickle.PicklingError("Failed to serialize") from e if buffer_callback is not None: for b in buffers: buffer_callback(b) @@ -90,6 +90,6 @@ def loads(x, *, buffers=()): return pickle.loads(x, buffers=buffers) else: return pickle.loads(x) - except Exception: + except Exception as e: logger.info("Failed to deserialize %s", x[:10000], exc_info=True) - raise + raise pickle.UnpicklingError("Failed to deserialize") from e From d746d5062626a7c41511c9bd586d5917711468a5 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 10 Oct 2024 20:41:27 +0200 Subject: [PATCH 2/3] Adjust tests --- distributed/protocol/pickle.py | 2 ++ distributed/protocol/tests/test_pickle.py | 8 ++++---- distributed/tests/test_client.py | 18 ++++++++++++++++-- distributed/tests/test_spill.py | 2 +- distributed/tests/test_worker.py | 9 ++++++--- distributed/tests/test_worker_memory.py | 2 +- 6 files changed, 30 insertions(+), 11 deletions(-) diff --git a/distributed/protocol/pickle.py b/distributed/protocol/pickle.py index e0d1143de4..0b0988df8e 100644 --- a/distributed/protocol/pickle.py +++ b/distributed/protocol/pickle.py @@ -90,6 +90,8 @@ def loads(x, *, buffers=()): return pickle.loads(x, buffers=buffers) else: return pickle.loads(x) + except EOFError: + raise except Exception as e: logger.info("Failed to deserialize %s", x[:10000], exc_info=True) raise pickle.UnpicklingError("Failed to deserialize") from e diff --git a/distributed/protocol/tests/test_pickle.py b/distributed/protocol/tests/test_pickle.py index 3d9ef38de7..b5fd563206 100644 --- a/distributed/protocol/tests/test_pickle.py +++ b/distributed/protocol/tests/test_pickle.py @@ -15,7 +15,7 @@ from distributed.protocol import deserialize, serialize from distributed.protocol.pickle import HIGHEST_PROTOCOL, dumps, loads from distributed.protocol.serialize import dask_deserialize, dask_serialize -from distributed.utils_test import popen, save_sys_modules +from distributed.utils_test import popen, raises_with_cause, save_sys_modules class MemoryviewHolder: @@ -231,7 +231,7 @@ def _deserialize_nopickle(header, frames): def test_allow_pickle_if_registered_in_dask_serialize(): - with pytest.raises(TypeError, match="nope"): + with raises_with_cause(pickle.PicklingError, "serialize", TypeError, "nope"): dumps(NoPickle()) dask_serialize.register(NoPickle)(_serialize_nopickle) @@ -251,9 +251,9 @@ def __init__(self) -> None: def test_nopickle_nested(): nested_obj = [NoPickle()] - with pytest.raises(TypeError, match="nope"): + with raises_with_cause(pickle.PicklingError, "serialize", TypeError, "nope"): dumps(nested_obj) - with pytest.raises(TypeError, match="nope"): + with raises_with_cause(pickle.PicklingError, "serialize", TypeError, "nope"): dumps(NestedNoPickle()) dask_serialize.register(NoPickle)(_serialize_nopickle) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 0016294e70..fe86dcf927 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5100,7 +5100,14 @@ def __setstate__(self, state): future = c.submit(identity, Foo()) await wait(future) assert future.status == "error" - with raises_with_cause(RuntimeError, "deserialization", MyException, "hello"): + with raises_with_cause( + RuntimeError, + "deserialization", + pickle.UnpicklingError, + "deserialize", + MyException, + "hello", + ): await future futures = c.map(inc, range(10)) @@ -5125,7 +5132,14 @@ def __call__(self, *args): future = c.submit(Foo(), 1) await wait(future) assert future.status == "error" - with raises_with_cause(RuntimeError, "deserialization", MyException, "hello"): + with raises_with_cause( + RuntimeError, + "deserialization", + pickle.UnpicklingError, + "deserialize", + MyException, + "hello", + ): await future futures = c.map(inc, range(10)) diff --git a/distributed/tests/test_spill.py b/distributed/tests/test_spill.py index 604f7b98c7..7372a6b43e 100644 --- a/distributed/tests/test_spill.py +++ b/distributed/tests/test_spill.py @@ -222,7 +222,7 @@ def test_spillbuffer_fail_to_serialize(tmp_path): with pytest.raises(TypeError, match="Failed to pickle 'a'") as e: with captured_logger("distributed.spill") as logs_bad_key: buf["a"] = a - assert isinstance(e.value.__cause__.__cause__, MyError) + assert isinstance(e.value.__cause__.__cause__.__cause__, MyError) # spill.py must remain silent because we're already logging in worker.py assert not logs_bad_key.getvalue() diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 0a2fed3219..f2949b6518 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -6,6 +6,7 @@ import itertools import logging import os +import pickle import random import sys import tempfile @@ -38,6 +39,7 @@ get_client, get_worker, profile, + protocol, wait, ) from distributed.comm.registry import backends @@ -46,7 +48,6 @@ from distributed.core import CommClosedError, Status, rpc from distributed.diagnostics.plugin import ForwardOutput from distributed.metrics import time -from distributed.protocol import pickle from distributed.scheduler import KilledWorker, Scheduler from distributed.utils import get_mp_context, wait_for from distributed.utils_test import ( @@ -509,13 +510,15 @@ async def test_plugin_internal_exception(): with raises_with_cause( RuntimeError, "Worker failed to start", + pickle.UnpicklingError, + "deserialize", UnicodeDecodeError, - match_cause="codec can't decode", + "codec can't decode", ): async with Worker( s.address, plugins={ - b"corrupting pickle" + pickle.dumps(lambda: None), + b"corrupting pickle" + protocol.pickle.dumps(lambda: None), }, ) as w: pass diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py index 4a222861d8..1f3b4be72d 100644 --- a/distributed/tests/test_worker_memory.py +++ b/distributed/tests/test_worker_memory.py @@ -176,7 +176,7 @@ async def test_fail_to_pickle_execute_1(c, s, a, b): with pytest.raises(TypeError, match="Failed to pickle 'x'") as e: await x - assert isinstance(e.value.__cause__.__cause__, CustomError) + assert isinstance(e.value.__cause__.__cause__.__cause__, CustomError) await assert_basic_futures(c) From 53f0039e27813c25e23881fa47ff264faf017334 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 11 Oct 2024 14:26:30 +0200 Subject: [PATCH 3/3] Improve clean_exception --- distributed/core.py | 3 +++ distributed/protocol/pickle.py | 4 ++-- distributed/tests/test_core.py | 11 +++++++++++ distributed/tests/test_steal.py | 2 ++ 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index a4bb031c12..69ea049677 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -6,6 +6,7 @@ import logging import math import os +import pickle import sys import tempfile import threading @@ -1706,6 +1707,8 @@ def clean_exception( if isinstance(exception, (bytes, bytearray)): try: exception = protocol.pickle.loads(exception) + except pickle.UnpicklingError as e: + exception = e except Exception: exception = Exception(exception) elif isinstance(exception, str): diff --git a/distributed/protocol/pickle.py b/distributed/protocol/pickle.py index 0b0988df8e..1dd46b646f 100644 --- a/distributed/protocol/pickle.py +++ b/distributed/protocol/pickle.py @@ -77,7 +77,7 @@ def dumps(x, *, buffer_callback=None, protocol=HIGHEST_PROTOCOL): result = cloudpickle.dumps(x, **dump_kwargs) except Exception as e: logger.exception("Failed to serialize %s.", x) - raise pickle.PicklingError("Failed to serialize") from e + raise pickle.PicklingError("Failed to serialize", x, buffers) from e if buffer_callback is not None: for b in buffers: buffer_callback(b) @@ -94,4 +94,4 @@ def loads(x, *, buffers=()): raise except Exception as e: logger.info("Failed to deserialize %s", x[:10000], exc_info=True) - raise pickle.UnpicklingError("Failed to deserialize") from e + raise pickle.UnpicklingError("Failed to deserialize", x, buffers) from e diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 6f4836ae08..fdcad7b505 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -4,6 +4,7 @@ import contextlib import logging import os +import pickle import random import socket import sys @@ -167,6 +168,16 @@ async def test_server_raises_on_blocked_handlers(): await comm.close() +@gen_test() +async def test_unreadable_exception_in_clean_exception(): + pickled_ex = b"\x80\x04\x954\x00\x00\x00\x00\x00\x00\x00\x8c\x08__main__\x94\x8c\x14SomeUnknownException\x94\x93\x94\x8c\x08some arg\x94\x85\x94R\x94." + ex_type, ex, tb = clean_exception(pickled_ex) + assert ex_type == pickle.UnpicklingError + assert isinstance(ex, pickle.UnpicklingError) + assert ex.args[1] == pickled_ex + assert tb is None + + class MyServer(Server): default_port = 8756 diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 3976857c9e..8f62a2bca3 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -70,6 +70,8 @@ async def test_work_stealing(c, s, a, b): await wait(futures) assert len(a.data) > 10 assert len(b.data) > 10 + del futures + print("Hi") @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)