Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve logging by adding jax_process, error logs in threads and more... #1028

Merged
merged 1 commit into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Fix callsites of handler.async_save to handle returned None.

### Changed
- Improve logging by adding jax_process, error logs in threads and more...


## [0.5.23] - 2024-07-26

Expand Down
29 changes: 27 additions & 2 deletions checkpoint/orbax/checkpoint/async_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,13 @@ def _thread_func(
unique_operation_id: str,
):
"""Awaits on commit futures and finalizes the checkpoint."""
current_process = multihost.process_index()
# The unique_operation_id allows pre-selecting an identifier to use for the
# barriers in this background thread. If we have multiple background
# threads running concurrently, relying on _module_unique_count can result
# in deadlocks when threads on different processes arrive at the barriers
# in a certain order.
try:
current_process = multihost.process_index()
process_count = jax.process_count()
logging.info(
'Starting commit to storage layer by process: %s', current_process
Expand Down Expand Up @@ -156,6 +156,12 @@ def _thread_func(
)

except Exception as e: # pylint: disable=broad-exception-caught
msg = (
f'[process={current_process}] Failed to run'
f' {len(commit_futures)} commit threads or the commit callback,'
f' directory: {directory}'
)
logging.error(msg, exc_info=True)
self._exception = e

def start_async_commit(
Expand Down Expand Up @@ -304,7 +310,11 @@ def save(
raise ValueError(f'Destination {directory} already exists.')
tmpdir = self.create_temporary_path(directory)

logging.info('Async saving checkpoint to %s.', directory)
logging.info(
'Async saving checkpoint to tmp dir=%s, eventually to final dir=%s.',
tmpdir.get(),
tmpdir.get_final(),
)
# Run copy ops.
# Try to save using new CheckpointArgs API if supported by the handler.
ckpt_args = checkpointer.construct_checkpoint_args(
Expand All @@ -318,9 +328,24 @@ def save(

# Directory is the final directory.
def _callback() -> None:
logging.info(
'Async Save Callback [1/3]: Finalizing Handler: %s on %s',
self._handler,
tmpdir.get(),
)
self._handler.finalize(tmpdir.get())
logging.info(
'Async Save Callback [2/3]: Running'
' post_finalization_callback: %s on %s',
self._post_finalization_callback,
tmpdir.get_final(),
)
if self._post_finalization_callback is not None:
self._post_finalization_callback()
logging.info(
'Async Save Callback [3/3]: Finalizing checkpoint directory: %s',
tmpdir.get(),
)
_on_commit_callback(tmpdir, checkpoint_start_time)

self._async_manager.start_async_commit(
Expand Down
24 changes: 19 additions & 5 deletions checkpoint/orbax/checkpoint/multihost/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,19 @@ def get_barrier_sync_fn(

def _fn(*, key: str, timeout_ms: int) -> None:
key = _unique_barrier_key(key)
logging.info('Waiting at barrier: %s', key)
logging.info('[process=%s] Waiting at barrier: %s', process_index(), key)
if processes is None:
client.wait_at_barrier(key, timeout_ms)
else:
logging.debug('Current process: %d', process_index())
logging.debug('Barrier processes: %s', barrier_processes)
logging.debug(
'[process=%s] Barrier processes: %s',
process_index(),
barrier_processes,
)
client.wait_at_barrier(key, timeout_ms, process_ids=barrier_processes)
logging.info('Done waiting at barrier: %s', key)
logging.info(
'[process=%s] Done waiting at barrier: %s', process_index(), key
)

return _fn

Expand Down Expand Up @@ -239,7 +244,16 @@ def sync_global_processes(

def reached_preemption(step: int) -> bool:
"""Returns True if a preemption sync point has been reached."""
return multihost_utils.reached_preemption_sync_point(step)
preemption_sync_point_reached = multihost_utils.reached_preemption_sync_point(
step
)
if preemption_sync_point_reached:
logging.warning(
'[process=%s] Reached preemption sync point, step=%s',
process_index(),
step,
)
return preemption_sync_point_reached


def is_primary_host(primary_host: Optional[int]):
Expand Down
6 changes: 3 additions & 3 deletions checkpoint/orbax/checkpoint/path/atomicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,9 @@ def _create_tmp_directory(
else:
raise FileExistsError(
f'Attempted to create temporary directory {tmp_dir} which already'
' exists.'
' exists but could not be resolved as a checkpoint tmp directory.'
)
logging.info('Creating tmp directory %s', tmp_dir)
tmp_dir.mkdir(parents=True, exist_ok=False, mode=path_permission_mode)
if checkpoint_metadata_store is not None:
checkpoint_metadata_store.write(
Expand Down Expand Up @@ -313,7 +314,6 @@ def create(
Raises:
FileExistsError: if tmp directory already exists.
"""
logging.info('Creating tmp directory %s', self._tmp_path)
mode = step_lib.WORLD_READABLE_MODE # pylint: disable=unused-variable
mode = (
file_options.path_permission_mode or self._path_permission_mode or mode
Expand Down Expand Up @@ -492,4 +492,4 @@ def on_commit_callback(
"""
tmp_dir.finalize()
step_lib.record_saved_duration(checkpoint_start_time)
logging.info('Finished saving checkpoint to `%s`.', tmp_dir.get_final())
logging.info('Committed checkpoint save to `%s`.', tmp_dir.get_final())
Loading