Skip to content

Commit

Permalink
Merge branch 'dev' into aziz/job_kill
Browse files Browse the repository at this point in the history
  • Loading branch information
abyesilyurt authored May 14, 2024
2 parents 7d77b36 + bf873e3 commit 2c3815a
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 17 deletions.
36 changes: 22 additions & 14 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,15 +770,30 @@ def add_output_policy_ids(cls, values: Any) -> Any:
def kwargs(self) -> dict[Any, Any] | None:
return self.input_policy_init_kwargs

def __call__(self, *args: Any, syft_no_node: bool = False, **kwargs: Any) -> Any:
def __call__(
self,
*args: Any,
syft_no_node: bool = False,
blocking: bool = False,
time_alive: int | None = None,
n_consumers: int = 2,
**kwargs: Any,
) -> Any:
if syft_no_node:
return self.local_call(*args, **kwargs)
return self._ephemeral_node_call(*args, **kwargs)
return self._ephemeral_node_call(
*args,
time_alive=time_alive,
n_consumers=n_consumers,
blocking=blocking,
**kwargs,
)

def local_call(self, *args: Any, **kwargs: Any) -> Any:
# only run this on the client side
if self.local_function:
tree = ast.parse(inspect.getsource(self.local_function))
source = dedent(inspect.getsource(self.local_function))
tree = ast.parse(source)

# check there are no globals
v = GlobalsVisitor()
Expand All @@ -803,9 +818,10 @@ def local_call(self, *args: Any, **kwargs: Any) -> Any:

def _ephemeral_node_call(
self,
time_alive: int | None = None,
n_consumers: int | None = None,
*args: Any,
time_alive: int | None = None,
n_consumers: int = 2,
blocking: bool = False,
**kwargs: Any,
) -> Any:
# relative
Expand All @@ -814,15 +830,7 @@ def _ephemeral_node_call(
# Right now we only create a number of workers
# In the future we might need to have the same pools/images as well

if n_consumers is None:
print(
SyftInfo(
message="Creating a node with n_consumers=2 (the default value)"
)
)
n_consumers = 2

if time_alive is None and "blocking" in kwargs and not kwargs["blocking"]:
if time_alive is None and not blocking:
print(
SyftInfo(
message="Closing the node after time_alive=300 (the default value)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def compute() -> int:

client_low_ds.refresh()
res = client_low_ds.code.compute(blocking=True)
assert res == compute(blocking=True).get()
assert res == compute(syft_no_node=True)


def test_sync_with_error(low_worker, high_worker):
Expand Down
8 changes: 7 additions & 1 deletion packages/syft/tests/syft/users/local_execution_test.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
# stdlib
from collections import OrderedDict
import sys

# third party
import numpy as np
import pytest

# syft absolute
import syft as sy
from syft.client.api import APIRegistry


@pytest.mark.skipif(sys.platform == "win32", reason="does not run on windows")
def test_local_execution(worker):
root_domain_client = worker.root_client
dataset = sy.Dataset(
Expand Down Expand Up @@ -40,5 +43,8 @@ def my_func(x):
return x + 1

# time.sleep(10)
local_res = my_func(x=asset, time_alive=1)
local_res = my_func(
x=asset,
time_alive=1,
)
assert (local_res == np.array([2, 2, 2])).all()
2 changes: 1 addition & 1 deletion packages/syft/tests/syft/users/user_code_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_duplicated_user_code(worker, guest_client: User) -> None:

# request the a different function name but same content will also succeed
# flaky if not blocking
mock_syft_func_2(blocking=True)
mock_syft_func_2(syft_no_node=True)
result = guest_client.api.services.code.request_code_execution(mock_syft_func_2)
assert isinstance(result, Request)
assert len(guest_client.code.get_all()) == 2
Expand Down

0 comments on commit 2c3815a

Please sign in to comment.