diff --git a/poetry.lock b/poetry.lock index 1780815..f1e1593 100644 --- a/poetry.lock +++ b/poetry.lock @@ -115,6 +115,14 @@ category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +[[package]] +name = "tblib" +version = "1.7.0" +description = "Traceback serialization library." +category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" + [[package]] name = "typing-extensions" version = "4.4.0" @@ -160,7 +168,7 @@ server = [] [metadata] lock-version = "1.1" python-versions = ">=3.7,<4.0" -content-hash = "5d2172c3d31c2f2de606884e3e6ff842c5aa82637bb33e227a073cde5161daac" +content-hash = "0a5c65e4262d6400277faca3109be24bc7c3e83c88ec22bf9b3ad327f1a42825" [metadata.files] commonmark = [ @@ -258,6 +266,10 @@ six = [ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] +tblib = [ + {file = "tblib-1.7.0-py2.py3-none-any.whl", hash = "sha256:289fa7359e580950e7d9743eab36b0691f0310fce64dee7d9c31065b8f723e23"}, + {file = "tblib-1.7.0.tar.gz", hash = "sha256:059bd77306ea7b419d4f76016aef6d7027cc8a0785579b5aad198803435f882c"}, +] typing-extensions = [ {file = "typing_extensions-4.4.0-py3-none-any.whl", hash = "sha256:16fa4864408f655d35ec496218b85f79b3437c829e93320c7c9215ccfd92489e"}, {file = "typing_extensions-4.4.0.tar.gz", hash = "sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa"}, diff --git a/pyproject.toml b/pyproject.toml index ef442d4..415101e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ importlib-metadata = ">=4.4" rich = ">=12.0" grpcio = ">=1.49" protobuf = "*" +tblib = "^1.7.0" [tool.poetry.extras] grpc = [] diff --git a/src/isolate/connections/common.py b/src/isolate/connections/common.py index ba09603..04b927e 100644 --- a/src/isolate/connections/common.py +++ b/src/isolate/connections/common.py @@ -3,7 +3,9 @@ import importlib import os from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Iterator, cast +from typing import TYPE_CHECKING, Any, Iterator, Optional, cast + +from tblib import Traceback, TracebackParseError if TYPE_CHECKING: from typing import Protocol @@ -55,6 +57,7 @@ def load_serialized_object( raw_object: bytes, *, was_it_raised: bool = False, + stringized_traceback: Optional[str] = None, ) -> Any: """Load the given serialized object using the given serialization method. If anything fails, then a SerializationError will be raised. If the was_it_raised @@ -70,7 +73,7 @@ def load_serialized_object( result = serialization_backend.loads(raw_object) if was_it_raised: - raise result + raise prepare_exc(result, stringized_traceback=stringized_traceback) else: return result @@ -91,3 +94,20 @@ def serialize_object(serialization_method: str, object: Any) -> bytes: def is_agent() -> bool: """Returns true if the current process is an isolate agent.""" return os.environ.get(AGENT_SIGNATURE) == "1" + + +def prepare_exc( + exc: BaseException, + *, + stringized_traceback: Optional[str] = None, +) -> BaseException: + if stringized_traceback: + try: + traceback = Traceback.from_string(stringized_traceback).as_traceback() + except TracebackParseError: + traceback = None + else: + traceback = None + + exc.__traceback__ = traceback + return exc diff --git a/src/isolate/connections/grpc/_base.py b/src/isolate/connections/grpc/_base.py index fac792b..436e3b4 100644 --- a/src/isolate/connections/grpc/_base.py +++ b/src/isolate/connections/grpc/_base.py @@ -88,6 +88,7 @@ def run( method=method, definition=serialize_object(method, executable), was_it_raised=False, + stringized_traceback=None, ) for partial_result in self._run_through_grpc(function): diff --git a/src/isolate/connections/grpc/agent.py b/src/isolate/connections/grpc/agent.py index 32704fd..513403c 100644 --- a/src/isolate/connections/grpc/agent.py +++ b/src/isolate/connections/grpc/agent.py @@ -52,11 +52,14 @@ def Run( yield from self.log("Starting the execution of the input function.") was_it_raised = False + stringized_tb = None try: result = function() except BaseException as exc: result = exc was_it_raised = True + num_frames = len(traceback.extract_stack()[:-5]) + stringized_tb = "".join(traceback.format_exc(limit=-num_frames)) yield from self.log("Completed the execution of the input function.") @@ -79,6 +82,7 @@ def Run( method=request.method, definition=definition, was_it_raised=was_it_raised, + stringized_traceback=stringized_tb, ) yield definitions.PartialRunResult( result=serialized_obj, diff --git a/src/isolate/connections/grpc/definitions/common.proto b/src/isolate/connections/grpc/definitions/common.proto index 4974a61..14d1d72 100644 --- a/src/isolate/connections/grpc/definitions/common.proto +++ b/src/isolate/connections/grpc/definitions/common.proto @@ -9,6 +9,8 @@ message SerializedObject { // A flag indicating whether the given object was raised (e.g. an exception // that was captured) or not. bool was_it_raised = 3; + // The stringized version of the traceback, if it was raised. + optional string stringized_traceback = 4; } message PartialRunResult { diff --git a/src/isolate/connections/grpc/definitions/common_pb2.py b/src/isolate/connections/grpc/definitions/common_pb2.py index e05ef4c..680dd98 100644 --- a/src/isolate/connections/grpc/definitions/common_pb2.py +++ b/src/isolate/connections/grpc/definitions/common_pb2.py @@ -13,7 +13,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x0c\x63ommon.proto"M\n\x10SerializedObject\x12\x0e\n\x06method\x18\x01 \x01(\t\x12\x12\n\ndefinition\x18\x02 \x01(\x0c\x12\x15\n\rwas_it_raised\x18\x03 \x01(\x08"n\n\x10PartialRunResult\x12\x13\n\x0bis_complete\x18\x01 \x01(\x08\x12\x12\n\x04logs\x18\x02 \x03(\x0b\x32\x04.Log\x12&\n\x06result\x18\x03 \x01(\x0b\x32\x11.SerializedObjectH\x00\x88\x01\x01\x42\t\n\x07_result"L\n\x03Log\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x1a\n\x06source\x18\x02 \x01(\x0e\x32\n.LogSource\x12\x18\n\x05level\x18\x03 \x01(\x0e\x32\t.LogLevel*.\n\tLogSource\x12\x0b\n\x07\x42UILDER\x10\x00\x12\n\n\x06\x42RIDGE\x10\x01\x12\x08\n\x04USER\x10\x02*Z\n\x08LogLevel\x12\t\n\x05TRACE\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x08\n\x04INFO\x10\x02\x12\x0b\n\x07WARNING\x10\x03\x12\t\n\x05\x45RROR\x10\x04\x12\n\n\x06STDOUT\x10\x05\x12\n\n\x06STDERR\x10\x06\x62\x06proto3' + b'\n\x0c\x63ommon.proto"\x89\x01\n\x10SerializedObject\x12\x0e\n\x06method\x18\x01 \x01(\t\x12\x12\n\ndefinition\x18\x02 \x01(\x0c\x12\x15\n\rwas_it_raised\x18\x03 \x01(\x08\x12!\n\x14stringized_traceback\x18\x04 \x01(\tH\x00\x88\x01\x01\x42\x17\n\x15_stringized_traceback"n\n\x10PartialRunResult\x12\x13\n\x0bis_complete\x18\x01 \x01(\x08\x12\x12\n\x04logs\x18\x02 \x03(\x0b\x32\x04.Log\x12&\n\x06result\x18\x03 \x01(\x0b\x32\x11.SerializedObjectH\x00\x88\x01\x01\x42\t\n\x07_result"L\n\x03Log\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x1a\n\x06source\x18\x02 \x01(\x0e\x32\n.LogSource\x12\x18\n\x05level\x18\x03 \x01(\x0e\x32\t.LogLevel*.\n\tLogSource\x12\x0b\n\x07\x42UILDER\x10\x00\x12\n\n\x06\x42RIDGE\x10\x01\x12\x08\n\x04USER\x10\x02*Z\n\x08LogLevel\x12\t\n\x05TRACE\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x08\n\x04INFO\x10\x02\x12\x0b\n\x07WARNING\x10\x03\x12\t\n\x05\x45RROR\x10\x04\x12\n\n\x06STDOUT\x10\x05\x12\n\n\x06STDERR\x10\x06\x62\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -21,14 +21,14 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _LOGSOURCE._serialized_start = 285 - _LOGSOURCE._serialized_end = 331 - _LOGLEVEL._serialized_start = 333 - _LOGLEVEL._serialized_end = 423 - _SERIALIZEDOBJECT._serialized_start = 16 - _SERIALIZEDOBJECT._serialized_end = 93 - _PARTIALRUNRESULT._serialized_start = 95 - _PARTIALRUNRESULT._serialized_end = 205 - _LOG._serialized_start = 207 - _LOG._serialized_end = 283 + _LOGSOURCE._serialized_start = 346 + _LOGSOURCE._serialized_end = 392 + _LOGLEVEL._serialized_start = 394 + _LOGLEVEL._serialized_end = 484 + _SERIALIZEDOBJECT._serialized_start = 17 + _SERIALIZEDOBJECT._serialized_end = 154 + _PARTIALRUNRESULT._serialized_start = 156 + _PARTIALRUNRESULT._serialized_end = 266 + _LOG._serialized_start = 268 + _LOG._serialized_end = 344 # @@protoc_insertion_point(module_scope) diff --git a/src/isolate/connections/grpc/definitions/common_pb2.pyi b/src/isolate/connections/grpc/definitions/common_pb2.pyi index 0132e17..46053f9 100644 --- a/src/isolate/connections/grpc/definitions/common_pb2.pyi +++ b/src/isolate/connections/grpc/definitions/common_pb2.pyi @@ -73,6 +73,7 @@ class SerializedObject(google.protobuf.message.Message): METHOD_FIELD_NUMBER: builtins.int DEFINITION_FIELD_NUMBER: builtins.int WAS_IT_RAISED_FIELD_NUMBER: builtins.int + STRINGIZED_TRACEBACK_FIELD_NUMBER: builtins.int method: builtins.str """The serialization method used to serialize the the raw_object. Must be present in the environment that is running the agent itself. @@ -83,24 +84,46 @@ class SerializedObject(google.protobuf.message.Message): """A flag indicating whether the given object was raised (e.g. an exception that was captured) or not. """ + stringized_traceback: builtins.str + """The stringized version of the traceback, if it was raised.""" def __init__( self, *, method: builtins.str = ..., definition: builtins.bytes = ..., was_it_raised: builtins.bool = ..., + stringized_traceback: builtins.str | None = ..., ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_stringized_traceback", + b"_stringized_traceback", + "stringized_traceback", + b"stringized_traceback", + ], + ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ + "_stringized_traceback", + b"_stringized_traceback", "definition", b"definition", "method", b"method", + "stringized_traceback", + b"stringized_traceback", "was_it_raised", b"was_it_raised", ], ) -> None: ... + def WhichOneof( + self, + oneof_group: typing_extensions.Literal[ + "_stringized_traceback", b"_stringized_traceback" + ], + ) -> typing_extensions.Literal["stringized_traceback"] | None: ... global___SerializedObject = SerializedObject diff --git a/src/isolate/connections/grpc/interface.py b/src/isolate/connections/grpc/interface.py index 01e485b..808fefe 100644 --- a/src/isolate/connections/grpc/interface.py +++ b/src/isolate/connections/grpc/interface.py @@ -2,7 +2,7 @@ and the Isolate Server to share.""" import functools -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional from isolate.connections.common import load_serialized_object, serialize_object from isolate.connections.grpc import definitions @@ -29,6 +29,7 @@ def _(message: definitions.SerializedObject) -> Any: message.method, message.definition, was_it_raised=message.was_it_raised, + stringized_traceback=message.stringized_traceback, ) @@ -53,11 +54,15 @@ def _(obj: Log) -> definitions.Log: def to_serialized_object( - obj: Any, method: str, was_it_raised: bool = False + obj: Any, + method: str, + was_it_raised: bool = False, + stringized_traceback: Optional[str] = None, ) -> definitions.SerializedObject: """Convert a Python object into a gRPC message.""" return definitions.SerializedObject( method=method, definition=serialize_object(method, obj), was_it_raised=was_it_raised, + stringized_traceback=stringized_traceback, ) diff --git a/src/isolate/connections/ipc/_base.py b/src/isolate/connections/ipc/_base.py index 80b6ce4..7ddc47d 100644 --- a/src/isolate/connections/ipc/_base.py +++ b/src/isolate/connections/ipc/_base.py @@ -24,6 +24,7 @@ EnvironmentConnection, ) from isolate.connections._local import PythonExecutionBase, agent_startup +from isolate.connections.common import prepare_exc from isolate.connections.ipc import agent from isolate.logs import LogLevel, LogSource @@ -191,11 +192,12 @@ def poll_until_result( # TODO(fix): handle EOFError that might happen here (e.g. problematic # serialization might cause it). - result, did_it_raise = connection.recv() + result, did_it_raise, stringized_traceback = connection.recv() if did_it_raise: - raise result + raise prepare_exc(result, stringized_traceback=stringized_traceback) else: + assert stringized_traceback is None return result diff --git a/src/isolate/connections/ipc/agent.py b/src/isolate/connections/ipc/agent.py index df983b3..9adc36d 100644 --- a/src/isolate/connections/ipc/agent.py +++ b/src/isolate/connections/ipc/agent.py @@ -20,6 +20,7 @@ import os import sys import time +import traceback from argparse import ArgumentParser from contextlib import closing from multiprocessing.connection import Client @@ -109,21 +110,22 @@ def run_client( result = None did_it_raise = False + stringized_tb = None try: result = callable() except BaseException as exc: result = exc did_it_raise = True + num_frames = len(traceback.extract_stack()[:-4]) + stringized_tb = "".join(traceback.format_exc(limit=-num_frames)) finally: try: - connection.send((result, did_it_raise)) + connection.send((result, did_it_raise, stringized_tb)) except BaseException: if did_it_raise: # If we can't even send it through the connection # still try to dump it to the stderr as the last # resort. - import traceback - assert isinstance(result, BaseException) traceback.print_exception( type(result), diff --git a/tests/test_connections.py b/tests/test_connections.py index 0ff437d..17ca5be 100644 --- a/tests/test_connections.py +++ b/tests/test_connections.py @@ -1,4 +1,5 @@ import operator +import traceback from dataclasses import replace from functools import partial from pathlib import Path @@ -144,6 +145,43 @@ def test_is_agent(self): assert not is_agent() assert not is_agent() + def test_tracebacks(self): + local_env = LocalPythonEnvironment() + local_env.apply_settings( + local_env.settings.replace(serialization_method="dill") + ) + + def long_function_chain(): + def foo(): + a = 1 + b = 0 + c = a / b + return c + + def bar(): + a = str() + str() + return 0 + foo() + 1 + + def baz(): + return bar() + 1 + + return baz() + + with self.open_connection(local_env, local_env.create()) as conn: + with pytest.raises(ZeroDivisionError) as exc: + conn.run(long_function_chain) + + exception = "".join( + traceback.format_exception( + type(exc.value), exc.value, exc.value.__traceback__ + ) + ) + assert "c = a / b" in exception + assert "return 0 + foo() + 1" in exception + assert "return bar() + 1" in exception + assert "return baz()" in exception + assert "conn.run(long_function_chain)" in exception + class TestPythonIPC(GenericPythonConnectionTests): def open_connection(