diff --git a/icechunk-python/tests/test_stateful_repo_ops.py b/icechunk-python/tests/test_stateful_repo_ops.py index 159a464e2..b1052bade 100644 --- a/icechunk-python/tests/test_stateful_repo_ops.py +++ b/icechunk-python/tests/test_stateful_repo_ops.py @@ -8,11 +8,12 @@ from collections.abc import Iterator from dataclasses import dataclass from functools import partial -from typing import Any, Self, cast +from typing import Any, Literal, Self, cast import numpy as np import pytest +import icechunk from zarr.core.buffer import Buffer, default_buffer_prototype pytest.importorskip("hypothesis") @@ -35,7 +36,13 @@ ) import zarr.testing.strategies as zrst -from icechunk import IcechunkError, Repository, SnapshotInfo, in_memory_storage +from icechunk import ( + IcechunkError, + Repository, + SnapshotInfo, + Storage, + in_memory_storage, +) from zarr.testing.stateful import SyncStoreWrapper # JSON file contents, keep it simple @@ -111,12 +118,13 @@ def __init__(self, **kwargs: Any) -> None: self.HEAD: None | str = None self.branch: None | str = None + # we maintain a list of `commits` == snapshots in repo object file + # and `ondisk_snaps` = commits + expired_snapshots + # expired snapshots are removed from the repo object file and can be garbage collected self.commits: dict[str, CommitModel] = {} self.ondisk_snaps: dict[str, CommitModel] = {} self.tags: dict[str, TagModel] = {} - # TODO: This is only tracking the HEAD, - # Should we model the branch as an ordered list of commits? 
- self.branches: dict[str, str] = {} + self.branch_heads: dict[str, str] = {} # a tag once created, can never be recreated even after expiration self.created_tags: set[str] = set() @@ -124,7 +132,7 @@ def __init__(self, **kwargs: Any) -> None: def __repr__(self) -> str: return textwrap.dedent(f""" - Branches: {tuple(self.branches.keys())!r} + Branches: {tuple(self.branch_heads.keys())!r} Tags: {tuple(self.tags.keys())!r}""").strip("\n") def __setitem__(self, key: str, value: Buffer) -> None: @@ -147,6 +155,11 @@ def __getitem__(self, key: str) -> Buffer: def has_commits(self) -> bool: return bool(self.commits) + @property + def commit_times(self) -> list[datetime.datetime]: + """Return sorted list of all commit times.""" + return sorted(c.written_at for c in self.commits.values()) + def commit(self, snap: SnapshotInfo) -> None: ref = snap.id self.commits[ref] = CommitModel.from_snapshot_and_store( @@ -157,7 +170,12 @@ def commit(self, snap: SnapshotInfo) -> None: self.HEAD = ref assert self.branch is not None - self.branches[self.branch] = ref + self.branch_heads[self.branch] = ref + + def amend(self, snap: SnapshotInfo) -> None: + """Amend the HEAD commit.""" + # this is simple because we aren't modeling the branch as a list of commits + self.commit(snap) def checkout_commit(self, ref: str) -> None: assert str(ref) in self.commits @@ -170,18 +188,19 @@ def checkout_commit(self, ref: str) -> None: def create_branch(self, name: str, commit: str) -> None: assert commit in self.commits - self.branches[name] = commit + self.branch_heads[name] = commit def checkout_branch(self, ref: str) -> None: - self.checkout_commit(self.branches[ref]) + self.checkout_commit(self.branch_heads[ref]) self.branch = ref + # TODO: add `from_snapshot_id` to this def reset_branch(self, branch: str, commit: str) -> None: assert commit in self.commits - self.branches[branch] = commit + self.branch_heads[branch] = commit def delete_branch(self, branch_name: str) -> None: - del 
self.branches[branch_name] + del self.branch_heads[branch_name] def delete_tag(self, tag: str) -> None: del self.tags[tag] @@ -200,7 +219,7 @@ def list_prefix(self, prefix: str) -> tuple[str, ...]: def refs_iter(self) -> Iterator[str]: tag_iter = map(operator.attrgetter("commit_id"), self.tags.values()) - return itertools.chain(self.branches.values(), tag_iter) + return itertools.chain(self.branch_heads.values(), tag_iter) def expire_snapshots( self, @@ -211,7 +230,7 @@ def expire_snapshots( ) -> ExpireInfo: # model this exactly like icechunk does. expired_snaps = set() - branch_pointees = set(self.branches.values()) + branch_pointees = set(self.branch_heads.values()) tag_pointees = set(map(operator.attrgetter("commit_id"), self.tags.values())) for snap in self.commits.values(): if ( @@ -219,7 +238,7 @@ def expire_snapshots( and snap.parent_id is not None and (delete_expired_tags or snap.id not in tag_pointees) and ( - (delete_expired_branches and self.branches["main"] != snap.id) + (delete_expired_branches and self.branch_heads["main"] != snap.id) or snap.id not in branch_pointees ) ): @@ -229,6 +248,7 @@ def expire_snapshots( for id in expired_snaps: # notice we don't delete from self.ondisk_snaps, those can still be deleted by GC + # however we do pop from `commits` since that is a list of unexpired snaps self.commits.pop(id, None) for c in self.commits.values(): @@ -248,12 +268,12 @@ def expire_snapshots( if delete_expired_branches: branches_to_delete = { k - for k, v in self.branches.items() + for k, v in self.branch_heads.items() if k != DEFAULT_BRANCH and v in expired_snaps } note(f"deleting branches {branches_to_delete=!r}") for branch in branches_to_delete: - note(f"deleting {branch=!r}, {self.branches[branch]=!r}") + note(f"deleting {branch=!r}, {self.branch_heads[branch]=!r}") self.delete_branch(branch) else: branches_to_delete = set() @@ -300,15 +320,12 @@ def __init__(self) -> None: note("----------") self.model = Model() - self.commit_times: 
list[datetime.datetime] = [] + self.storage: Storage | None = None - @initialize(data=st.data(), target=branches) - def initialize(self, data: st.DataObject) -> str: - # FIXME: currently this test is IC2 only - spec_version = data.draw( - st.one_of(st.integers(min_value=2, max_value=2), st.none()) - ) - self.repo = Repository.create(in_memory_storage(), spec_version=spec_version) + @initialize(data=st.data(), target=branches, spec_version=st.sampled_from([1, 2])) + def initialize(self, data: st.DataObject, spec_version: Literal[1, 2]) -> str: + self.storage = in_memory_storage() + self.repo = Repository.create(self.storage, spec_version=spec_version) self.session = self.repo.writable_session(DEFAULT_BRANCH) snap = next(iter(self.repo.ancestry(branch=DEFAULT_BRANCH))) @@ -348,6 +365,34 @@ def set_doc(self, path: str, value: Buffer) -> None: with pytest.raises(IcechunkError, match="read-only store"): self.sync_store.set(path, value) + @rule() + @precondition(lambda self: self.repo.spec_version == 1) + def upgrade_spec_version(self) -> None: + # don't test simple cases of catching error upgrading a v2 spec + # that should be covered in unit tests + icechunk.upgrade_icechunk_repository(self.repo) + # TODO: remove the reopen after https://github.com/earth-mover/icechunk/issues/1521 + self.reopen_repository() + + @rule() + def reopen_repository(self) -> None: + """Reopen the repository from storage to get fresh state. + + This discards any uncommitted changes. 
+ """ + assert self.storage is not None, "storage must be initialized" + self.repo = Repository.open(self.storage) + note(f"Reopened repository (spec_version={self.repo.spec_version})") + + # Reopening discards uncommitted changes - reset model to last committed state + branch = ( + self.model.branch + if self.model.branch in self.model.branch_heads + else DEFAULT_BRANCH + ) + self.session = self.repo.writable_session(branch) + self.model.checkout_branch(branch) + @rule(message=st.text(max_size=MAX_TEXT_SIZE), target=commits) @precondition(lambda self: self.model.changes_made) def commit(self, message: str) -> str: @@ -359,7 +404,30 @@ def commit(self, message: str) -> str: self.session = self.repo.writable_session(branch) note(f"Created commit: {snapinfo!r}") self.model.commit(snapinfo) - self.commit_times.append(snapinfo.written_at) + return commit_id + + @rule(message=st.text(max_size=MAX_TEXT_SIZE), target=commits) + # TODO: update changes made rule depending on result of + # https://github.com/earth-mover/icechunk/issues/1532 + @precondition( + lambda self: (self.model.changes_made) + and (self.repo.spec_version >= 2) + and len(self.model.commits) > 1 + ) + def amend(self, message: str) -> str: + branch = self.session.branch + assert branch is not None + old_head = next(iter(self.repo.ancestry(branch=branch))) + note(f"Amending commit on branch {branch!r} with id {old_head!r}") + + commit_id = self.session.amend(message) + snapinfo = next(iter(self.repo.ancestry(branch=branch))) + assert snapinfo.id == commit_id + note(f"Amended commit: {snapinfo!r}") + self.session = self.repo.writable_session(branch) + + # Update model + self.model.amend(snapinfo) return commit_id @rule(ref=commits) @@ -395,7 +463,8 @@ def checkout_tag(self, ref: str) -> None: @rule(ref=branches) def checkout_branch(self, ref: str) -> None: # TODO: sometimes readonly? 
- if self.model.branches.get(ref) in self.model.commits: + branch_head = self.model.branch_heads.get(ref) + if branch_head is not None and branch_head in self.model.commits: note(f"Checking out branch {ref!r}") self.session = self.repo.writable_session(ref) assert not self.session.read_only @@ -410,7 +479,7 @@ def create_branch(self, name: str, commit: str) -> str: note(f"Creating branch {name!r}") # we can create a tag and branch with the same name - if name not in self.model.branches and commit in self.model.commits: + if name not in self.model.branch_heads and commit in self.model.commits: self.repo.create_branch(name, commit) self.model.create_branch(name, commit) else: @@ -462,13 +531,13 @@ def discard_changes(self) -> None: @precondition(lambda self: not self.model.changes_made) @rule(branch=branches, commit=commits) def reset_branch(self, branch: str, commit: str) -> None: - if branch not in self.model.branches or commit not in self.model.commits: + if branch not in self.model.branch_heads or commit not in self.model.commits: note(f"resetting branch {branch}, expecting error.") with pytest.raises(IcechunkError): self.repo.reset_branch(branch, commit) else: note( - f"resetting branch {branch} from {self.model.branches[branch]} to {commit}" + f"resetting branch {branch} from {self.model.branch_heads[branch]} to {commit}" ) self.repo.reset_branch(branch, commit) self.model.reset_branch(branch, commit) @@ -486,7 +555,7 @@ def maybe_checkout_branch( @rule(branch=consumes(branches)) def delete_branch(self, branch: str) -> None: note(f"Deleting branch {branch!r}") - if branch in self.model.branches: + if branch in self.model.branch_heads: if branch == DEFAULT_BRANCH: note("Expecting error.") with pytest.raises( @@ -503,7 +572,10 @@ def delete_branch(self, branch: str) -> None: with pytest.raises(IcechunkError): self.repo.delete_branch(branch) - @precondition(lambda self: bool(self.commit_times)) + # TODO: v1 has bugs in expire_snapshots, only test for v2 + # 
https://github.com/earth-mover/icechunk/issues/1520 + # https://github.com/earth-mover/icechunk/issues/1534 + @precondition(lambda self: bool(self.model.commits) and self.repo.spec_version == 2) @rule( data=st.data(), delta=st.timedeltas( @@ -519,17 +591,29 @@ def expire_snapshots( delete_expired_branches: bool, delete_expired_tags: bool, ) -> None: - commit_time = data.draw(st.sampled_from(self.commit_times)) + commit_time = data.draw(st.sampled_from(self.model.commit_times)) older_than = commit_time + delta note( f"Expiring snapshots {older_than=!r}, ({commit_time=!r}, {delta=!r}), {delete_expired_branches=!r}, {delete_expired_tags=!r}" ) + + # Track branches and tags before expiration + branches_before = set(self.repo.list_branches()) + tags_before = set(self.repo.list_tags()) + actual = self.repo.expire_snapshots( older_than, delete_expired_branches=delete_expired_branches, delete_expired_tags=delete_expired_tags, ) note(f"repo expired snaps={actual!r}") + + # Track branches and tags after expiration + branches_after = set(self.repo.list_branches()) + tags_after = set(self.repo.list_tags()) + actual_deleted_branches = branches_before - branches_after + actual_deleted_tags = tags_before - tags_after + expected = self.model.expire_snapshots( older_than, delete_expired_branches=delete_expired_branches, @@ -537,13 +621,34 @@ def expire_snapshots( ) note(f"from model: {expected}") note(f"actual: {actual}") + note(f"actual_deleted_branches: {actual_deleted_branches}") + note(f"actual_deleted_tags: {actual_deleted_tags}") + assert self.initial_snapshot.id not in actual assert actual == expected.expired_snapshots, (actual, expected) - for branch in expected.deleted_branches: + # Check that expired snapshots are actually removed from ancestry + remaining_snapshot_ids = set() + for branch in branches_after: + for snap in self.repo.ancestry(branch=branch): + remaining_snapshot_ids.add(snap.id) + expired_but_remaining = actual & remaining_snapshot_ids + 
note(expired_but_remaining) + assert ( + not expired_but_remaining + ), f"Snapshots marked as expired but still in ancestry: {expired_but_remaining}" + + assert ( + actual_deleted_branches == expected.deleted_branches + ), f"deleted branches mismatch: actual={actual_deleted_branches}, expected={expected.deleted_branches}" + assert ( + actual_deleted_tags == expected.deleted_tags + ), f"deleted tags mismatch: actual={actual_deleted_tags}, expected={expected.deleted_tags}" + + for branch in actual_deleted_branches: self.maybe_checkout_branch(branch) - @precondition(lambda self: bool(self.commit_times)) + @precondition(lambda self: bool(self.model.commit_times)) @rule( data=st.data(), # we delete based on snapshot created_at time, not flushed_at time @@ -552,7 +657,7 @@ def expire_snapshots( delta=st.integers(min_value=-86400, max_value=86400).filter(lambda x: x != 0), ) def garbage_collect(self, data: st.DataObject, delta: int) -> None: - commit_time = data.draw(st.sampled_from(self.commit_times)) + commit_time = data.draw(st.sampled_from(self.model.commit_times)) older_than = commit_time + datetime.timedelta(seconds=delta) note( f"running garbage_collect for {older_than=!r}, ({commit_time=!r}, {delta=!r})" @@ -573,6 +678,7 @@ def garbage_collect(self, data: st.DataObject, delta: int) -> None: self.model.checkout_branch(DEFAULT_BRANCH) def check_commit(self, commit: str) -> None: + # utility function, not an invariant assume(commit in self.model.commits) note(f"Checking {commit=!r}") expected = self.model.commits[commit] @@ -583,6 +689,15 @@ def check_commit(self, commit: str) -> None: assert actual.written_at == expected.written_at @invariant() + def checks(self) -> None: + # this method only exists to reduce verbosity of hypothesis output + # It cannot be called `check_invariants` because that clashes + # with an existing method on the superclass + self.check_list_prefix_from_root() + self.check_tags() + self.check_branches() + self.check_ancestry() + def 
check_list_prefix_from_root(self) -> None: model_list = self.model.list_prefix("") ice_list = self.sync_store.list_prefix("") @@ -602,7 +717,6 @@ def check_list_prefix_from_root(self) -> None: np.testing.assert_allclose(actual_fv, expected_fv) assert actual == expected - @invariant() def check_tags(self) -> None: expected_tags = self.model.tags actual_tags = { @@ -611,12 +725,10 @@ def check_tags(self) -> None: } assert actual_tags == expected_tags - @invariant() def check_branches(self) -> None: repo_branches = {k: self.repo.lookup_branch(k) for k in self.repo.list_branches()} - assert self.model.branches == repo_branches + assert self.model.branch_heads == repo_branches - @invariant() def check_ancestry(self) -> None: for branch in self.repo.list_branches(): ancestry = list(self.repo.ancestry(branch=branch))