Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Freeze protocol version #8452

Merged
merged 10 commits into from
Feb 9, 2024
125 changes: 113 additions & 12 deletions packages/syft/src/syft/protocol/data_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@
import re
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type
from typing import Union

# third party
from packaging.version import parse
from result import OkErr
from result import Result

# relative
from .. import __version__
from ..serde.recursive import TYPE_BANK
from ..service.response import SyftError
from ..service.response import SyftException
Expand All @@ -30,29 +33,33 @@
PROTOCOL_TYPE = Union[str, int]


def natural_key(key: PROTOCOL_TYPE) -> list[int]:
def natural_key(key: PROTOCOL_TYPE) -> List[int]:
"""Define key for natural ordering of strings."""
if isinstance(key, int):
key = str(key)
return [int(s) if s.isdigit() else s for s in re.split("(\d+)", key)]


def sort_dict_naturally(d: dict) -> dict:
def sort_dict_naturally(d: Dict) -> Dict:
"""Sort dictionary by keys in natural order."""
return {k: d[k] for k in sorted(d.keys(), key=natural_key)}


def data_protocol_file_name():
def data_protocol_file_name() -> str:
return PROTOCOL_STATE_FILENAME


def data_protocol_dir():
return os.path.abspath(str(Path(__file__).parent))
def data_protocol_dir() -> Path:
return Path(os.path.abspath(str(Path(__file__).parent)))


def protocol_release_dir() -> Path:
return data_protocol_dir() / "releases"


class DataProtocol:
def __init__(self, filename: str) -> None:
self.file_path = Path(data_protocol_dir()) / filename
self.file_path = data_protocol_dir() / filename
self.load_state()

def load_state(self) -> None:
Expand All @@ -78,13 +85,31 @@ def _calculate_object_hash(klass: Type[SyftBaseObject]) -> str:

return hashlib.sha256(json.dumps(obj_meta_info).encode()).hexdigest()

def read_history(self) -> Dict:
@staticmethod
def read_json(file_path: Path) -> Dict:
try:
return json.loads(self.file_path.read_text())
return json.loads(file_path.read_text())
except Exception:
return {}

def save_history(self, history: dict) -> None:
def read_history(self) -> Dict:
protocol_history = self.read_json(self.file_path)

for version in protocol_history.keys():
if version == "dev":
continue
release_version_path = (
protocol_release_dir() / protocol_history[version]["release_name"]
)
released_version = self.read_json(file_path=release_version_path)
protocol_history[version] = released_version.get(version, {})

return protocol_history

def save_history(self, history: Dict) -> None:
for file_path in protocol_release_dir().iterdir():
for version in self.read_json(file_path):
history[version] = {"release_name": file_path.name}
self.file_path.write_text(json.dumps(history, indent=2) + "\n")

@property
Expand Down Expand Up @@ -136,7 +161,7 @@ def build_state(self, stop_key: Optional[str] = None) -> dict:
return state_dict
return state_dict

def diff_state(self, state: dict) -> tuple[dict, dict]:
def diff_state(self, state: Dict) -> tuple[Dict, Dict]:
compare_dict = defaultdict(dict) # what versions are in the latest code
object_diff = defaultdict(dict) # diff in latest code with saved json
for k in TYPE_BANK:
Expand Down Expand Up @@ -274,6 +299,7 @@ def bump_protocol_version(self) -> Result[SyftSuccess, SyftError]:

keys = self.protocol_history.keys()
if "dev" not in keys:
self.validate_release()
print("You can't bump the protocol if there are no staged changes.")
return SyftError(
message="Failed to bump version as there are no staged changes."
Expand All @@ -287,11 +313,86 @@ def bump_protocol_version(self) -> Result[SyftSuccess, SyftError]:

next_highest_protocol = highest_protocol + 1
self.protocol_history[str(next_highest_protocol)] = self.protocol_history["dev"]
self.freeze_release(self.protocol_history, str(next_highest_protocol))
del self.protocol_history["dev"]
self.save_history(self.protocol_history)
self.load_state()
return SyftSuccess(message=f"Protocol Updated to {next_highest_protocol}")

@staticmethod
def freeze_release(protocol_history: Dict, latest_protocol: str) -> None:
release_history = protocol_history[latest_protocol]
syft_version = parse(__version__)
release_file_name = f"{syft_version.public}.json"
release_file = protocol_release_dir() / release_file_name
release_file.write_text(
json.dumps({latest_protocol: release_history}, indent=2)
)

def validate_release(self) -> None:
protocol_history = self.read_json(self.file_path)
sorted_protocol_versions = sorted(protocol_history.keys(), key=natural_key)

latest_protocol = (
sorted_protocol_versions[-1] if len(sorted_protocol_versions) > 0 else None
)

if latest_protocol is None or latest_protocol == "dev":
return

release_name = protocol_history[latest_protocol]["release_name"]
syft_version = parse(release_name.split(".json")[0])
current_syft_version = parse(__version__)

if syft_version.base_version != current_syft_version.base_version:
return

print(
f"Current release {release_name} will be updated to {current_syft_version}"
)

curr_protocol_file_path: Path = protocol_release_dir() / release_name
new_protocol_file_path = (
protocol_release_dir() / f"{current_syft_version.public}.json"
)
curr_protocol_file_path.rename(new_protocol_file_path)
protocol_history[latest_protocol][
"release_name"
] = f"{current_syft_version}.json"

self.file_path.write_text(json.dumps(protocol_history, indent=2) + "\n")
self.read_history()

def revert_latest_protocol(self) -> Result[SyftSuccess, SyftError]:
"""Revert latest protocol changes to dev"""

# Get current protocol history
protocol_history = self.read_json(self.file_path)

# Get latest released protocol
sorted_protocol_versions = sorted(protocol_history.keys(), key=natural_key)
latest_protocol = (
sorted_protocol_versions[-1] if len(sorted_protocol_versions) > 0 else None
)

# If current protocol is dev, skip revert
if latest_protocol is None or latest_protocol == "dev":
return SyftError("Revert skipped !! Already running dev protocol.")

# Read the current released protocol
release_name = protocol_history[latest_protocol]["release_name"]
protocol_file_path: Path = protocol_release_dir() / release_name

released_protocol = self.read_json(protocol_file_path)
protocol_history["dev"] = released_protocol[latest_protocol]

# Delete the current released protocol
protocol_history.pop(latest_protocol)
protocol_file_path.unlink()

# Save history
self.save_history(protocol_history)

def check_protocol(self) -> Result[SyftSuccess, SyftError]:
if len(self.diff) != 0:
return SyftError(message="Protocol Changes Unstaged")
Expand Down Expand Up @@ -338,7 +439,7 @@ def has_dev(self) -> bool:
return False


def get_data_protocol():
def get_data_protocol() -> DataProtocol:
return DataProtocol(filename=data_protocol_file_name())


Expand All @@ -357,7 +458,7 @@ def check_or_stage_protocol() -> Result[SyftSuccess, SyftError]:
return data_protocol.check_or_stage_protocol()


def debox_arg_and_migrate(arg: Any, protocol_state: dict):
def debox_arg_and_migrate(arg: Any, protocol_state: dict) -> Any:
"""Debox the argument based on whether it is iterable or single entity."""
constructor = None
extra_args = []
Expand Down
Loading
Loading