Skip to content

Commit

Permalink
Merge pull request #9178 from OpenMined/eelco/l2-test-flow-fixes
Browse files Browse the repository at this point in the history
L2 flow fixes
  • Loading branch information
tcp authored Aug 16, 2024
2 parents 909b61d + e151321 commit 259f67d
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 10 deletions.
40 changes: 34 additions & 6 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,8 @@ def assets(self) -> DictTuple[str, Asset] | SyftError:
all_inputs = {}
inputs = self.input_policy_init_kwargs or {}
for vals in inputs.values():
all_inputs.update(vals)
# Only keep UIDs, filter out Constants
all_inputs.update({k: v for k, v in vals.items() if isinstance(v, UID)})

# map the action_id to the asset
used_assets: list[Asset] = []
Expand Down Expand Up @@ -753,20 +754,47 @@ def action_objects(self) -> dict:
action_objects = {
arg_name: str(uid)
for arg_name, uid in all_inputs.items()
if arg_name not in self.assets.keys()
if arg_name not in self.assets.keys() and isinstance(uid, UID)
}

return action_objects

@property
def constants(self) -> dict[str, Constant]:
if not self.input_policy_init_kwargs:
return {}

all_inputs = {}
for vals in self.input_policy_init_kwargs.values():
all_inputs.update(vals)

# filter out the assets
constants = {
arg_name: item
for arg_name, item in all_inputs.items()
if isinstance(item, Constant)
}

return constants

@property
def inputs(self) -> dict:
inputs = {}
if self.action_objects:
inputs["action_objects"] = self.action_objects
if self.assets:

assets = self.assets
action_objects = self.action_objects
constants = self.constants
if action_objects:
inputs["action_objects"] = action_objects
if assets:
inputs["assets"] = {
argument: asset._get_dict_for_user_code_repr()
for argument, asset in self.assets.items()
for argument, asset in assets.items()
}
if self.constants:
inputs["constants"] = {
argument: constant._get_dict_for_user_code_repr()
for argument, constant in constants.items()
}
return inputs

Expand Down
6 changes: 3 additions & 3 deletions packages/syft/src/syft/service/job/job_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,9 @@ def check_user_code_id(self) -> Self:

@property
def result_id(self) -> UID | None:
if self.result is None:
return None
return self.result.id.id
if isinstance(self.result, ActionObject):
return self.result.id.id
return None

@property
def action_display_name(self) -> str:
Expand Down
3 changes: 3 additions & 0 deletions packages/syft/src/syft/service/output/output_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ def from_ids(
)
else:
job_link = None

if input_ids is not None:
input_ids = {k: v for k, v in input_ids.items() if isinstance(v, UID)}
return cls(
output_ids=output_ids,
user_code_link=user_code_link,
Expand Down
9 changes: 9 additions & 0 deletions packages/syft/src/syft/service/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,15 @@ def transform_kwarg(
return Ok(obj.syft_action_data)
return Ok(self.val)

def _get_dict_for_user_code_repr(self) -> dict[str, Any]:
return self._coll_repr_()

def _coll_repr_(self) -> dict[str, Any]:
return {
"klass": self.klass.__qualname__,
"val": str(self.val),
}


@serializable()
class UserOwned(PolicyRule):
Expand Down
11 changes: 10 additions & 1 deletion packages/syft/src/syft/service/request/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,8 @@ def _create_output_history_for_deposited_result(
if input_policy is not None:
for input_ in input_policy.inputs.values():
input_ids.update(input_)

input_ids = {k: v for k, v in input_ids.items() if isinstance(v, UID)}
res = api.services.code.store_execution_output(
user_code_id=code.id,
outputs=result,
Expand Down Expand Up @@ -1088,6 +1090,7 @@ def _deposit_result_l2(
for inps in code.input_policy.inputs.values():
input_ids.update(inps)

input_ids = {k: v for k, v in input_ids.items() if isinstance(v, UID)}
res = api.services.code.store_execution_output(
user_code_id=code.id,
outputs=result,
Expand All @@ -1104,7 +1107,13 @@ def _deposit_result_l2(
else JobStatus.COMPLETED
)

existing_result = job.result.id if job.result is not None else None
existing_result = None
if isinstance(job.result, ActionObject):
existing_result = job.result.id
elif isinstance(job.result, Err):
existing_result = job.result
else:
existing_result = job.result
print(
f"Job({job.id}) Setting new result {existing_result} -> {job_info.result.id}"
)
Expand Down

0 comments on commit 259f67d

Please sign in to comment.