Skip to content

Commit

Permalink
Merge pull request #8378 from OpenMined/safer_execution
Browse files Browse the repository at this point in the history
Safer execution
  • Loading branch information
teo-milea authored Jan 9, 2024
2 parents 3767bde + a2582ff commit d74cf0b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 29 deletions.
51 changes: 24 additions & 27 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,9 @@

# relative
from ...abstract_node import NodeType
from ...client.api import APIRegistry
from ...client.api import NodeIdentity
from ...client.client import PythonConnection
from ...client.enclave_client import EnclaveMetadata
from ...node.credentials import SyftVerifyKey
from ...protocol.data_protocol import get_data_protocol
from ...serde.deserialize import _deserialize
from ...serde.serializable import serializable
from ...serde.serialize import _serialize
Expand Down Expand Up @@ -994,7 +991,7 @@ def __init__(self, context):
node = context.node
job_service = node.get_service("jobservice")
action_service = node.get_service("actionservice")
user_service = node.get_service("userservice")
# user_service = node.get_service("userservice")

def job_set_n_iters(n_iters):
job = context.job
Expand All @@ -1011,24 +1008,24 @@ def job_increase_current_iter(current_iter):
job.current_iter += current_iter
job_service.update(context, job)

def set_api_registry():
user_signing_key = [
x.signing_key
for x in user_service.stash.partition.data.values()
if x.verify_key == context.credentials
][0]
data_protcol = get_data_protocol()
user_api = node.get_api(context.credentials, data_protcol.latest_version)
user_api.signing_key = user_signing_key
# We hardcode a python connection here since we have access to the node
# TODO: this is not secure
user_api.connection = PythonConnection(node=node)

APIRegistry.set_api_for(
node_uid=node.id,
user_verify_key=context.credentials,
api=user_api,
)
# def set_api_registry():
# user_signing_key = [
# x.signing_key
# for x in user_service.stash.partition.data.values()
# if x.verify_key == context.credentials
# ][0]
# data_protcol = get_data_protocol()
# user_api = node.get_api(context.credentials, data_protcol.latest_version)
# user_api.signing_key = user_signing_key
# # We hardcode a python connection here since we have access to the node
# # TODO: this is not secure
# user_api.connection = PythonConnection(node=node)

# APIRegistry.set_api_for(
# node_uid=node.id,
# user_verify_key=context.credentials,
# api=user_api,
# )

def launch_job(func: UserCode, **kwargs):
# relative
Expand All @@ -1049,8 +1046,8 @@ def launch_job(func: UserCode, **kwargs):
parent_job_id=context.job_id,
has_execute_permissions=True,
)
# set api in global scope to enable using .get(), .wait())
set_api_registry()
# # set api in global scope to enable using .get(), .wait())
# set_api_registry()

return job
except Exception as e:
Expand Down Expand Up @@ -1147,7 +1144,8 @@ def to_str(arg: Any) -> str:
# statisfy lint checker
result = None

_locals = locals()
# We only need access to local kwargs
_locals = {"kwargs": kwargs}
_globals = {}

for service_func_name, (linked_obj, _) in code_item.nested_codes.items():
Expand All @@ -1156,7 +1154,7 @@ def to_str(arg: Any) -> str:
raise Exception(code_obj.err())
_globals[service_func_name] = code_obj.ok()
_globals["print"] = print
exec(code_item.parsed_code, _globals, locals()) # nosec
exec(code_item.parsed_code, _globals, _locals) # nosec

evil_string = f"{code_item.unique_func_name}(**kwargs)"
try:
Expand All @@ -1170,7 +1168,6 @@ def to_str(arg: Any) -> str:
)
log_service = context.node.get_service("LogService")
log_service.append(context=context, uid=log_id, new_err=error_msg)

result = Err(
f"Exception encountered while running {code_item.service_func_name}"
", please contact the Node Admin for more info."
Expand Down
4 changes: 2 additions & 2 deletions packages/syft/src/syft/store/sqlite_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def _set(self, key: UID, value: Any) -> None:
else:
insert_sql = (
f"insert into {self.table_name} (uid, repr, value) VALUES (?, ?, ?)" # nosec
) # nosec
)
data = _serialize(value, to_bytes=True)
res = self._execute(insert_sql, [str(key), _repr_debug_(value), data])
if res.is_err():
Expand All @@ -216,7 +216,7 @@ def _set(self, key: UID, value: Any) -> None:
def _update(self, key: UID, value: Any) -> None:
insert_sql = (
f"update {self.table_name} set uid = ?, repr = ?, value = ? where uid = ?" # nosec
) # nosec
)
data = _serialize(value, to_bytes=True)
res = self._execute(insert_sql, [str(key), _repr_debug_(value), data, str(key)])
if res.is_err():
Expand Down

0 comments on commit d74cf0b

Please sign in to comment.