Skip to content

Commit

Permalink
POC for Rust-like error handling using structural pattern matching
Browse files Browse the repository at this point in the history
  • Loading branch information
jcardonnet committed May 23, 2024
1 parent 755c01f commit d8a30de
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 66 deletions.
119 changes: 61 additions & 58 deletions packages/syft/src/syft/service/settings/settings_service.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
# stdlib

# stdlib

# stdlib
# type: ignore

# third party
from result import Err
from result import Ok
from result import Result

# relative
from ...serde.serializable import serializable
Expand Down Expand Up @@ -36,35 +31,44 @@ def __init__(self, store: DocumentStore) -> None:
self.stash = SettingsStash(store=store)

@service_method(path="settings.get", name="get")
def get(self, context: UnauthedServiceContext) -> Result[Ok, Err]:
"""Get Settings"""

result = self.stash.get_all(context.node.signing_key.verify_key)
if result.is_ok():
settings = result.ok()
# check if the settings list is empty
if len(settings) == 0:
def get(self, context: UnauthedServiceContext) -> NodeSettings | SyftError:
"""
Get the Node Settings
Returns:
NodeSettings | SyftError : The Node Settings or an error if no settings are found.
"""

result = self.stash.get(context.node.signing_key.verify_key)

match result: # type: ignore
case Ok(None):
return SyftError(message="No settings found")
result = settings[0]
return Ok(result)
else:
return SyftError(message=result.err())
case Ok(NodeSettings() as settings):
return settings
case Err(err_message):
return SyftError(message=err_message)

@service_method(path="settings.set", name="set")
def set(
self, context: AuthedServiceContext, settings: NodeSettings
) -> Result[Ok, Err]:
"""Set a new the Node Settings"""
) -> NodeSettings | SyftError:
"""
Set a new the Node Settings
Returns:
NodeSettings | SyftError : The new Node Settings or an error if the settings could not be set.
"""
result = self.stash.set(context.credentials, settings)
if result.is_ok():
return result
else:
return SyftError(message=result.err())

match result:
case Ok(settings):
return settings
case Err(err_message):
return SyftError(message=err_message)

@service_method(path="settings.update", name="update", autosplat=["settings"])
def update(
self, context: AuthedServiceContext, settings: NodeSettingsUpdate
) -> Result[SyftSuccess, SyftError]:
) -> SyftSuccess | SyftError:
"""
Update the Node Settings using the provided values.
Expand All @@ -84,34 +88,32 @@ def update(
association_request_auto_approval: Optional[bool]
Returns:
Result[SyftSuccess, SyftError]: A result indicating the success or failure of the update operation.
SyftSuccess | SyftError: A result indicating the success or failure of the update operation.
Example:
>>> node_client.update(name='foo', organization='bar', description='baz', signup_enabled=True)
SyftSuccess: Settings updated successfully.
"""

result = self.stash.get_all(context.credentials)
if result.is_ok():
current_settings = result.ok()
if len(current_settings) > 0:
new_settings = current_settings[0].model_copy(
result = self.get(context)
match result: # type: ignore
case NodeSettings():
new_settings = result.model_copy(
update=settings.to_dict(exclude_empty=True)
)
update_result = self.stash.update(context.credentials, new_settings)
if update_result.is_ok():
return SyftSuccess(
message=(
"Settings updated successfully. "
+ "You must call <client>.refresh() to sync your client with the changes."
match update_result:
case Ok():
return SyftSuccess(
message=(
"Settings updated successfully. "
+ "You must call <client>.refresh() to sync your client with the changes."
)
)
)
else:
return SyftError(message=update_result.err())
else:
return SyftError(message="No settings found")
else:
return SyftError(message=result.err())
case Err(err_message):
return SyftError(message=err_message)
case SyftError():
return result

@service_method(
path="settings.enable_notifications",
Expand Down Expand Up @@ -160,16 +162,15 @@ def allow_guest_signup(
"""Enable/Disable Registration for Data Scientist or Guest Users."""
flags.CAN_REGISTER = enable

method = context.node.get_service_method(SettingsService.update)
settings = NodeSettingsUpdate(signup_enabled=enable)

result = method(context=context, settings=settings)

if isinstance(result, SyftError):
return SyftError(message=f"Failed to update settings: {result.err()}")
new_settings = NodeSettingsUpdate(signup_enabled=enable)
result = self.update(context, settings=new_settings)

message = "enabled" if enable else "disabled"
return SyftSuccess(message=f"Registration feature successfully {message}")
match result: # type: ignore
case SyftSuccess():
flag = "enabled" if enable else "disabled"
return SyftSuccess(message=f"Registration feature successfully {flag}")
case SyftError():
return SyftError(message=f"Failed to update settings: {result}")

@service_method(
path="settings.allow_association_request_auto_approval",
Expand All @@ -180,10 +181,12 @@ def allow_association_request_auto_approval(
) -> SyftSuccess | SyftError:
new_settings = NodeSettingsUpdate(association_request_auto_approval=enable)
result = self.update(context, settings=new_settings)
if isinstance(result, SyftError):
return result

message = "enabled" if enable else "disabled"
return SyftSuccess(
message="Association request auto-approval successfully " + message
)
match result: # type: ignore
case SyftSuccess():
flag = "enabled" if enable else "disabled"
return SyftSuccess(
message=f"Association request auto-approval successfully {flag}"
)
case SyftError():
return SyftError(message=f"Failed to update settings: {result}")
42 changes: 34 additions & 8 deletions packages/syft/src/syft/service/settings/settings_stash.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# stdlib

# stdlib
from typing import Any

# third party
from result import Err
from result import Ok
from result import Result

# relative
Expand Down Expand Up @@ -30,6 +35,22 @@ class SettingsStash(BaseUIDStoreStash):
def __init__(self, store: DocumentStore) -> None:
super().__init__(store=store)

def check_type(self, obj: Any) -> Result[NodeSettings, str]:
if isinstance(obj, NodeSettings):
return Ok(obj)
else:
return Err(f"{type(obj)} does not match required type: {NodeSettings}")

def get(self, credentials: SyftVerifyKey) -> Result[NodeSettings | None, str]:
result = self.get_all(credentials=credentials)
match result:
case Ok(settings) if settings:
return Ok(None)
case Ok(settings):
return Ok(settings[0]) # type: ignore
case Err(err_message):
return Err(err_message)

def set(
self,
credentials: SyftVerifyKey,
Expand All @@ -38,20 +59,25 @@ def set(
add_storage_permission: bool = True,
ignore_duplicates: bool = False,
) -> Result[NodeSettings, str]:
res = self.check_type(settings, self.object_type)
result = self.check_type(settings)
# we dont use and_then logic here as it is hard because of the order of the arguments
if res.is_err():
return res
return super().set(credentials=credentials, obj=res.ok())
match result:
case Ok(obj):
return super().set(credentials=credentials, obj=obj) # type: ignore
case Err(error):
return Err(error)

def update(
self,
credentials: SyftVerifyKey,
settings: NodeSettings,
has_permission: bool = False,
) -> Result[NodeSettings, str]:
res = self.check_type(settings, self.object_type)
result = self.check_type(settings)
# we dont use and_then logic here as it is hard because of the order of the arguments
if res.is_err():
return res
return super().update(credentials=credentials, obj=res.ok())

match result:
case Ok(obj):
return super().update(credentials=credentials, obj=obj) # type: ignore
case Err(error):
return Err(error)

0 comments on commit d8a30de

Please sign in to comment.