Skip to content

Commit

Permalink
Improve logging by adding jax_process, error logs in threads and more...
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 657225817
  • Loading branch information
niketkumar authored and Orbax Authors committed Jul 29, 2024
1 parent 4bc8308 commit 25a5dd6
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 12 deletions.
3 changes: 3 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Fix callsites of handler.async_save to handle returned None.

### Changed
- Improve logging by adding jax_process, error logs in threads and more...


## [0.5.23] - 2024-07-26

Expand Down
31 changes: 28 additions & 3 deletions checkpoint/orbax/checkpoint/async_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,13 @@ def _thread_func(
unique_operation_id: str,
):
"""Awaits on commit futures and finalizes the checkpoint."""
current_process = multihost.process_index()
# The unique_operation_id allows pre-selecting an identifier to use for the
# barriers in this background thread. If we have multiple background
# threads running concurrently, relying on _module_unique_count can result
# in deadlocks when threads on different processes arrive at the barriers
# in a certain order.
try:
current_process = multihost.process_index()
process_count = jax.process_count()
logging.info(
'Starting commit to storage layer by process: %s', current_process
Expand Down Expand Up @@ -156,7 +156,13 @@ def _thread_func(
)

except Exception as e: # pylint: disable=broad-exception-caught
self._exception = e
msg = (
f'[process={current_process}] Failed to run'
f' {len(commit_futures)} commit threads or the commit callback,'
f' directory: {directory}'
)
logging.exception(msg)
self._exception = ExceptionGroup(msg, [e])

def start_async_commit(
self,
Expand Down Expand Up @@ -304,7 +310,11 @@ def save(
raise ValueError(f'Destination {directory} already exists.')
tmpdir = self.create_temporary_path(directory)

logging.info('Async saving checkpoint to %s.', directory)
logging.info(
'Async saving checkpoint to tmp dir=%s, eventually to final dir=%s.',
tmpdir.get(),
tmpdir.get_final(),
)
# Run copy ops.
# Try to save using new CheckpointArgs API if supported by the handler.
ckpt_args = checkpointer.construct_checkpoint_args(
Expand All @@ -318,9 +328,24 @@ def save(

# Directory is the final directory.
def _callback() -> None:
logging.info(
'Async Save Callback [1/3]: Finalizing Handler: %s on %s',
self._handler,
tmpdir.get(),
)
self._handler.finalize(tmpdir.get())
logging.info(
'Async Save Callback [2/3]: Running'
' post_finalization_callback: %s on %s',
self._post_finalization_callback,
tmpdir.get_final(),
)
if self._post_finalization_callback is not None:
self._post_finalization_callback()
logging.info(
'Async Save Callback [3/3]: Finalizing checkpoint directory: %s',
tmpdir.get(),
)
_on_commit_callback(tmpdir, checkpoint_start_time)

self._async_manager.start_async_commit(
Expand Down
24 changes: 19 additions & 5 deletions checkpoint/orbax/checkpoint/multihost/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,19 @@ def get_barrier_sync_fn(

def _fn(*, key: str, timeout_ms: int) -> None:
key = _unique_barrier_key(key)
logging.info('Waiting at barrier: %s', key)
logging.info('[process=%s] Waiting at barrier: %s', process_index(), key)
if processes is None:
client.wait_at_barrier(key, timeout_ms)
else:
logging.debug('Current process: %d', process_index())
logging.debug('Barrier processes: %s', barrier_processes)
logging.debug(
'[process=%s] Barrier processes: %s',
process_index(),
barrier_processes,
)
client.wait_at_barrier(key, timeout_ms, process_ids=barrier_processes)
logging.info('Done waiting at barrier: %s', key)
logging.info(
'[process=%s] Done waiting at barrier: %s', process_index(), key
)

return _fn

Expand Down Expand Up @@ -239,7 +244,16 @@ def sync_global_processes(

def reached_preemption(step: int) -> bool:
"""Returns True if a preemption sync point has been reached."""
return multihost_utils.reached_preemption_sync_point(step)
preemption_sync_point_reached = multihost_utils.reached_preemption_sync_point(
step
)
if preemption_sync_point_reached:
logging.warning(
'[process=%s] Reached preemption sync point, step=%s',
process_index(),
step,
)
return preemption_sync_point_reached


def is_primary_host(primary_host: Optional[int]):
Expand Down
8 changes: 4 additions & 4 deletions checkpoint/orbax/checkpoint/path/atomicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def _create_tmp_directory(
path_permission_mode: Path permission mode for the temp directory. e.g.
0o750. Please check
https://github.com/google/etils/blob/main/etils/epath/backend.py if your
path is supported.
path is supported.
checkpoint_metadata_store: optional `CheckpointMetadataStore` instance. If
present then it is used to create `CheckpointMetadata` with current
timestamp.
Expand Down Expand Up @@ -185,8 +185,9 @@ def _create_tmp_directory(
else:
raise FileExistsError(
f'Attempted to create temporary directory {tmp_dir} which already'
' exists.'
' exists but could not be resolved as a checkpoint tmp directory.'
)
logging.info('Creating tmp directory %s', tmp_dir)
tmp_dir.mkdir(parents=True, exist_ok=False, mode=path_permission_mode)
if checkpoint_metadata_store is not None:
checkpoint_metadata_store.write(
Expand Down Expand Up @@ -313,7 +314,6 @@ def create(
Raises:
FileExistsError: if tmp directory already exists.
"""
logging.info('Creating tmp directory %s', self._tmp_path)
mode = step_lib.WORLD_READABLE_MODE # pylint: disable=unused-variable
mode = (
file_options.path_permission_mode or self._path_permission_mode or mode
Expand Down Expand Up @@ -492,4 +492,4 @@ def on_commit_callback(
"""
tmp_dir.finalize()
step_lib.record_saved_duration(checkpoint_start_time)
logging.info('Finished saving checkpoint to `%s`.', tmp_dir.get_final())
logging.info('Committed checkpoint save to `%s`.', tmp_dir.get_final())

0 comments on commit 25a5dd6

Please sign in to comment.