Skip to content

Commit

Permalink
Improve clean_exception
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait committed Oct 11, 2024
1 parent d746d50 commit 53f0039
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 2 deletions.
3 changes: 3 additions & 0 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import math
import os
import pickle
import sys
import tempfile
import threading
Expand Down Expand Up @@ -1706,6 +1707,8 @@ def clean_exception(
if isinstance(exception, (bytes, bytearray)):
try:
exception = protocol.pickle.loads(exception)
except pickle.UnpicklingError as e:
exception = e
except Exception:
exception = Exception(exception)
elif isinstance(exception, str):
Expand Down
4 changes: 2 additions & 2 deletions distributed/protocol/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def dumps(x, *, buffer_callback=None, protocol=HIGHEST_PROTOCOL):
result = cloudpickle.dumps(x, **dump_kwargs)
except Exception as e:
logger.exception("Failed to serialize %s.", x)
raise pickle.PicklingError("Failed to serialize") from e
raise pickle.PicklingError("Failed to serialize", x, buffers) from e
if buffer_callback is not None:
for b in buffers:
buffer_callback(b)
Expand All @@ -94,4 +94,4 @@ def loads(x, *, buffers=()):
raise
except Exception as e:
logger.info("Failed to deserialize %s", x[:10000], exc_info=True)
raise pickle.UnpicklingError("Failed to deserialize") from e
raise pickle.UnpicklingError("Failed to deserialize", x, buffers) from e
11 changes: 11 additions & 0 deletions distributed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import contextlib
import logging
import os
import pickle
import random
import socket
import sys
Expand Down Expand Up @@ -167,6 +168,16 @@ async def test_server_raises_on_blocked_handlers():
await comm.close()


@gen_test()
async def test_unreadable_exception_in_clean_exception():
pickled_ex = b"\x80\x04\x954\x00\x00\x00\x00\x00\x00\x00\x8c\x08__main__\x94\x8c\x14SomeUnknownException\x94\x93\x94\x8c\x08some arg\x94\x85\x94R\x94."
ex_type, ex, tb = clean_exception(pickled_ex)
assert ex_type == pickle.UnpicklingError
assert isinstance(ex, pickle.UnpicklingError)
assert ex.args[1] == pickled_ex
assert tb is None


class MyServer(Server):
default_port = 8756

Expand Down
2 changes: 2 additions & 0 deletions distributed/tests/test_steal.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ async def test_work_stealing(c, s, a, b):
await wait(futures)
assert len(a.data) > 10
assert len(b.data) > 10
del futures
print("Hi")


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
Expand Down

0 comments on commit 53f0039

Please sign in to comment.