Skip to content

Commit

Permalink
Merge pull request #57 from fal-ai/batuhan/fea-651-attach-tracebacks-…
Browse files Browse the repository at this point in the history
…on-isolate-executions

feat: preserve tracebacks accross agents when possible
  • Loading branch information
isidentical authored Dec 1, 2022
2 parents 5d54e9e + 29f8d61 commit f8ac6a1
Show file tree
Hide file tree
Showing 12 changed files with 131 additions and 21 deletions.
14 changes: 13 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ importlib-metadata = ">=4.4"
rich = ">=12.0"
grpcio = ">=1.49"
protobuf = "*"
tblib = "^1.7.0"

[tool.poetry.extras]
grpc = []
Expand Down
24 changes: 22 additions & 2 deletions src/isolate/connections/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import importlib
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Iterator, cast
from typing import TYPE_CHECKING, Any, Iterator, Optional, cast

from tblib import Traceback, TracebackParseError

if TYPE_CHECKING:
from typing import Protocol
Expand Down Expand Up @@ -55,6 +57,7 @@ def load_serialized_object(
raw_object: bytes,
*,
was_it_raised: bool = False,
stringized_traceback: Optional[str] = None,
) -> Any:
"""Load the given serialized object using the given serialization method. If
anything fails, then a SerializationError will be raised. If the was_it_raised
Expand All @@ -70,7 +73,7 @@ def load_serialized_object(
result = serialization_backend.loads(raw_object)

if was_it_raised:
raise result
raise prepare_exc(result, stringized_traceback=stringized_traceback)
else:
return result

Expand All @@ -91,3 +94,20 @@ def serialize_object(serialization_method: str, object: Any) -> bytes:
def is_agent() -> bool:
"""Returns true if the current process is an isolate agent."""
return os.environ.get(AGENT_SIGNATURE) == "1"


def prepare_exc(
exc: BaseException,
*,
stringized_traceback: Optional[str] = None,
) -> BaseException:
if stringized_traceback:
try:
traceback = Traceback.from_string(stringized_traceback).as_traceback()
except TracebackParseError:
traceback = None
else:
traceback = None

exc.__traceback__ = traceback
return exc
1 change: 1 addition & 0 deletions src/isolate/connections/grpc/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def run(
method=method,
definition=serialize_object(method, executable),
was_it_raised=False,
stringized_traceback=None,
)

for partial_result in self._run_through_grpc(function):
Expand Down
4 changes: 4 additions & 0 deletions src/isolate/connections/grpc/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,14 @@ def Run(
yield from self.log("Starting the execution of the input function.")

was_it_raised = False
stringized_tb = None
try:
result = function()
except BaseException as exc:
result = exc
was_it_raised = True
num_frames = len(traceback.extract_stack()[:-5])
stringized_tb = "".join(traceback.format_exc(limit=-num_frames))

yield from self.log("Completed the execution of the input function.")

Expand All @@ -79,6 +82,7 @@ def Run(
method=request.method,
definition=definition,
was_it_raised=was_it_raised,
stringized_traceback=stringized_tb,
)
yield definitions.PartialRunResult(
result=serialized_obj,
Expand Down
2 changes: 2 additions & 0 deletions src/isolate/connections/grpc/definitions/common.proto
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ message SerializedObject {
// A flag indicating whether the given object was raised (e.g. an exception
// that was captured) or not.
bool was_it_raised = 3;
// The stringized version of the traceback, if it was raised.
optional string stringized_traceback = 4;
}

message PartialRunResult {
Expand Down
22 changes: 11 additions & 11 deletions src/isolate/connections/grpc/definitions/common_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions src/isolate/connections/grpc/definitions/common_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class SerializedObject(google.protobuf.message.Message):
METHOD_FIELD_NUMBER: builtins.int
DEFINITION_FIELD_NUMBER: builtins.int
WAS_IT_RAISED_FIELD_NUMBER: builtins.int
STRINGIZED_TRACEBACK_FIELD_NUMBER: builtins.int
method: builtins.str
"""The serialization method used to serialize the the raw_object. Must be
present in the environment that is running the agent itself.
Expand All @@ -83,24 +84,46 @@ class SerializedObject(google.protobuf.message.Message):
"""A flag indicating whether the given object was raised (e.g. an exception
that was captured) or not.
"""
stringized_traceback: builtins.str
"""The stringized version of the traceback, if it was raised."""
def __init__(
self,
*,
method: builtins.str = ...,
definition: builtins.bytes = ...,
was_it_raised: builtins.bool = ...,
stringized_traceback: builtins.str | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"_stringized_traceback",
b"_stringized_traceback",
"stringized_traceback",
b"stringized_traceback",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"_stringized_traceback",
b"_stringized_traceback",
"definition",
b"definition",
"method",
b"method",
"stringized_traceback",
b"stringized_traceback",
"was_it_raised",
b"was_it_raised",
],
) -> None: ...
def WhichOneof(
self,
oneof_group: typing_extensions.Literal[
"_stringized_traceback", b"_stringized_traceback"
],
) -> typing_extensions.Literal["stringized_traceback"] | None: ...

global___SerializedObject = SerializedObject

Expand Down
9 changes: 7 additions & 2 deletions src/isolate/connections/grpc/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
and the Isolate Server to share."""

import functools
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional

from isolate.connections.common import load_serialized_object, serialize_object
from isolate.connections.grpc import definitions
Expand All @@ -29,6 +29,7 @@ def _(message: definitions.SerializedObject) -> Any:
message.method,
message.definition,
was_it_raised=message.was_it_raised,
stringized_traceback=message.stringized_traceback,
)


Expand All @@ -53,11 +54,15 @@ def _(obj: Log) -> definitions.Log:


def to_serialized_object(
obj: Any, method: str, was_it_raised: bool = False
obj: Any,
method: str,
was_it_raised: bool = False,
stringized_traceback: Optional[str] = None,
) -> definitions.SerializedObject:
"""Convert a Python object into a gRPC message."""
return definitions.SerializedObject(
method=method,
definition=serialize_object(method, obj),
was_it_raised=was_it_raised,
stringized_traceback=stringized_traceback,
)
6 changes: 4 additions & 2 deletions src/isolate/connections/ipc/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
EnvironmentConnection,
)
from isolate.connections._local import PythonExecutionBase, agent_startup
from isolate.connections.common import prepare_exc
from isolate.connections.ipc import agent
from isolate.logs import LogLevel, LogSource

Expand Down Expand Up @@ -191,11 +192,12 @@ def poll_until_result(

# TODO(fix): handle EOFError that might happen here (e.g. problematic
# serialization might cause it).
result, did_it_raise = connection.recv()
result, did_it_raise, stringized_traceback = connection.recv()

if did_it_raise:
raise result
raise prepare_exc(result, stringized_traceback=stringized_traceback)
else:
assert stringized_traceback is None
return result


Expand Down
8 changes: 5 additions & 3 deletions src/isolate/connections/ipc/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import os
import sys
import time
import traceback
from argparse import ArgumentParser
from contextlib import closing
from multiprocessing.connection import Client
Expand Down Expand Up @@ -109,21 +110,22 @@ def run_client(

result = None
did_it_raise = False
stringized_tb = None
try:
result = callable()
except BaseException as exc:
result = exc
did_it_raise = True
num_frames = len(traceback.extract_stack()[:-4])
stringized_tb = "".join(traceback.format_exc(limit=-num_frames))
finally:
try:
connection.send((result, did_it_raise))
connection.send((result, did_it_raise, stringized_tb))
except BaseException:
if did_it_raise:
# If we can't even send it through the connection
# still try to dump it to the stderr as the last
# resort.
import traceback

assert isinstance(result, BaseException)
traceback.print_exception(
type(result),
Expand Down
38 changes: 38 additions & 0 deletions tests/test_connections.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import operator
import traceback
from dataclasses import replace
from functools import partial
from pathlib import Path
Expand Down Expand Up @@ -144,6 +145,43 @@ def test_is_agent(self):
assert not is_agent()
assert not is_agent()

def test_tracebacks(self):
local_env = LocalPythonEnvironment()
local_env.apply_settings(
local_env.settings.replace(serialization_method="dill")
)

def long_function_chain():
def foo():
a = 1
b = 0
c = a / b
return c

def bar():
a = str() + str()
return 0 + foo() + 1

def baz():
return bar() + 1

return baz()

with self.open_connection(local_env, local_env.create()) as conn:
with pytest.raises(ZeroDivisionError) as exc:
conn.run(long_function_chain)

exception = "".join(
traceback.format_exception(
type(exc.value), exc.value, exc.value.__traceback__
)
)
assert "c = a / b" in exception
assert "return 0 + foo() + 1" in exception
assert "return bar() + 1" in exception
assert "return baz()" in exception
assert "conn.run(long_function_chain)" in exception


class TestPythonIPC(GenericPythonConnectionTests):
def open_connection(
Expand Down

0 comments on commit f8ac6a1

Please sign in to comment.