Skip to content

Commit

Permalink
Garbage-collect (GC) potentially leftover steps at startup time for local_ckpt_mgr.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 685769033
  • Loading branch information
leiyiz authored and Orbax Authors committed Oct 14, 2024
1 parent f4ea9c0 commit 3fa202d
Showing 1 changed file with 18 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,17 +169,14 @@ class LocalCheckpointOptions:
Ensures checkpoints will only be saved every m steps. Defaults to 10.
max_to_keep:
Specifies the maximum number of local checkpoints to
keep. Older checkpoints are removed. When set, no more than `max_to_keep`
checkpoints will be present at any one time. This option has a slightly
different meaning than it normally does in Orbax: this should be treated
as a hard cap on the number of checkpoints concurrently present, rather
than a threshold beyond which checkpoints start to be deleted.
keep. Older checkpoints are removed. When set, no more than
(`max_to_keep` + 1) checkpoints will be present at any one time.
read_only:
If True, the local checkpoint manager will not save any checkpoints.
"""

save_interval_steps: int = 10
max_to_keep: int = 2
max_to_keep: int = 1
read_only: bool = False

debug_use_full_global_mesh: bool = False
Expand Down Expand Up @@ -467,6 +464,13 @@ def __init__(
self._local_options = options.local
self._steps = None

# Remove steps that might be left over from previous runs.
steps_to_remove = self._get_old_steps_to_remove()
self._checkpoints = [
info for info in self._checkpoints if info.step not in steps_to_remove
]
self._checkpoint_deleter.delete_steps(steps_to_remove)

def local_host_steps(self, read: bool) -> Sequence[int]:
"""Returns steps known to local host."""
# List of steps present in individual host storage.
Expand Down Expand Up @@ -981,8 +985,9 @@ def _get_single_slice_sharding(
),
self._abstract_state,
)
original_single_slice_shardings_tuple = tuple(jax.tree.flatten(
original_single_slice_shardings)[0])
original_single_slice_shardings_tuple = tuple(
jax.tree.flatten(original_single_slice_shardings)[0]
)

if is_restoring_slice:
logging.vlog(
Expand Down Expand Up @@ -1112,9 +1117,11 @@ def create_zeros(shape_dtype_tup):

return jax.tree.unflatten(tree_defs, shared_states)

def _consistent_restore_mesh_to_global_mesh(self, original_state: PyTree,
desired_slice_shardings,
) -> Any:
def _consistent_restore_mesh_to_global_mesh(
self,
original_state: PyTree,
desired_slice_shardings,
) -> Any:
"""Transfers from consistent restore mesh to global mesh."""
logging.info('Transferring from consistent restore mesh to global mesh')

Expand Down

0 comments on commit 3fa202d

Please sign in to comment.