Improve logging by adding jax_process, error logs in threads and more...
PiperOrigin-RevId: 657225817
niketkumar authored and Orbax Authors committed Jul 29, 2024
1 parent 529b25c commit c3c0039
Showing 7 changed files with 119 additions and 33 deletions.
7 changes: 7 additions & 0 deletions checkpoint/CHANGELOG.md
@@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Fixed
- Fix callsites of handler.async_save to handle returned None.

### Changed
- Improve logging by adding jax_process, error logs in threads and more...


## [0.5.23] - 2024-07-26

### Changed
9 changes: 4 additions & 5 deletions checkpoint/orbax/checkpoint/array_checkpoint_handler.py
@@ -106,8 +106,9 @@ async def async_save():
commit_futures = await self.async_save(directory, *args, **kwargs) # pytype: disable=bad-return-type
# Futures are already running, so sequential waiting is equivalent to
# concurrent waiting.
for f in commit_futures:
f.result() # Block on result.
if commit_futures: # May be None.
for f in commit_futures:
f.result() # Block on result.

asyncio.run(async_save())

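Taken out of context, the guard introduced above amounts to the following (a minimal, runnable sketch; `_fake_async_save` and its thread-pool future stand in for the real Orbax handler and its commit futures):

```python
import asyncio
from concurrent import futures
from typing import List, Optional

_POOL = futures.ThreadPoolExecutor(max_workers=1)


async def _fake_async_save() -> Optional[List[futures.Future]]:
  # Stand-in for handler.async_save: may return None instead of a list.
  return [_POOL.submit(lambda: 42)]


async def _save():
  commit_futures = await _fake_async_save()
  if commit_futures:  # May be None.
    for f in commit_futures:
      f.result()  # Futures are already running; blocking is safe.


asyncio.run(_save())
```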
@@ -154,9 +155,7 @@ def restore(
path=checkpoint_path,
parent_dir=directory,
skip_deserialize=False,
is_ocdbt_checkpoint=type_handlers.is_ocdbt_checkpoint(
directory
),
is_ocdbt_checkpoint=type_handlers.is_ocdbt_checkpoint(directory),
)
restore_type = restore_args.restore_type
if restore_type is None:
62 changes: 53 additions & 9 deletions checkpoint/orbax/checkpoint/async_checkpointer.py
@@ -61,8 +61,9 @@ def __init__(
barrier_sync_key_prefix: Optional[str] = None,
):
logging.info(
'Using timeout: %d secs and primary_host=%s for async checkpoint'
' writes',
'[process=%s] Using timeout: %d secs and primary_host=%s for async'
' checkpoint writes',
multihost.process_index(),
timeout_secs,
primary_host,
)
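The `[process=N]` prefix added here recurs throughout the commit. A hedged sketch of the convention (the `plog` helper is illustrative, not part of Orbax; `jax.process_index()` is the underlying JAX API that Orbax's `multihost` wrapper exposes):

```python
import logging
import jax


def plog(fmt: str, *args) -> None:
  # Prefix every record with the JAX process index so multi-host logs
  # can be filtered per process.
  logging.info('[process=%s] ' + fmt, jax.process_index(), *args)


plog('Using timeout: %d secs and primary_host=%s for async checkpoint'
     ' writes', 600, 0)
```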
@@ -95,24 +96,30 @@ def _thread_func(
unique_operation_id: str,
):
"""Awaits on commit futures and finalizes the checkpoint."""
current_process = multihost.process_index()
# The unique_operation_id allows pre-selecting an identifier to use for the
# barriers in this background thread. If we have multiple background
# threads running concurrently, relying on _module_unique_count can result
# in deadlocks when threads on different processes arrive at the barriers
# in a certain order.
try:
current_process = multihost.process_index()
process_count = jax.process_count()
logging.info(
'Starting commit to storage layer by process: %s', current_process
'[process=%s] Waiting for %s commit threads, directory: %s',
current_process,
len(commit_futures),
directory,
)
thread_start_time = time.time()

# Wait for commit operations to complete.
for future in commit_futures:
future.result()
logging.info(
'Finished committing to storage layer by process: %s', current_process
'[process=%s] Finished %s commit threads, directory: %s',
current_process,
len(commit_futures),
directory,
)
# Log the number of async writes that are in flight. Abuses a duration
# metric as a counter since jax.monitoring only has events and durations.
@@ -156,7 +163,13 @@ def _thread_func(
)

except Exception as e: # pylint: disable=broad-exception-caught
self._exception = e
msg = (
f'[process={current_process}] Failed to run'
f' {len(commit_futures)} commit threads or the commit callback,'
f' directory: {directory}'
)
logging.exception(msg)
self._exception = ExceptionGroup(msg, [e])

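Wrapping the caught exception in an `ExceptionGroup` preserves the original traceback while attaching process and directory context. A minimal sketch of the store-then-reraise pattern (names are illustrative; the built-in `ExceptionGroup` requires Python 3.11+, and older interpreters can use the `exceptiongroup` backport):

```python
import logging


class _CommitErrorState:
  """Stores a background-thread failure for the main thread to re-raise."""

  def __init__(self):
    self._exception = None

  def record(self, context_msg: str, e: Exception) -> None:
    # Log with traceback in the background thread, then keep the error
    # wrapped with its context for the main thread.
    logging.exception(context_msg)
    self._exception = ExceptionGroup(context_msg, [e])  # Python 3.11+.

  def check_for_errors(self) -> None:
    if self._exception is not None:
      exception = self._exception
      self._exception = None
      raise exception
```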
def start_async_commit(
self,
@@ -190,10 +203,16 @@ def wait_until_finished(self):
if self._thread is not None:
self._thread.join()
self._thread = None
logging.info('Commit thread joined successfully')
logging.info(
'[process=%s] Commit thread joined successfully',
multihost.process_index(),
)

self.check_for_errors()
logging.info('Commit thread error check finished successfully')
logging.info(
'[process=%s] Commit thread error check finished successfully',
multihost.process_index(),
)


class AsyncCheckpointer(checkpointer.Checkpointer):
@@ -304,7 +323,13 @@ def save(
raise ValueError(f'Destination {directory} already exists.')
tmpdir = self.create_temporary_path(directory)

logging.info('Async saving checkpoint to %s.', directory)
logging.info(
'[process=%s] Async saving checkpoint to tmp dir=%s, eventually to'
' final dir=%s.',
multihost.process_index(),
tmpdir.get(),
tmpdir.get_final(),
)
# Run copy ops.
# Try to save using new CheckpointArgs API if supported by the handler.
ckpt_args = checkpointer.construct_checkpoint_args(
@@ -318,9 +343,28 @@

# Directory is the final directory.
def _callback() -> None:
logging.info(
'[process=%s] Async Save Callback[1/3]: Finalizing Handler: %s on %s',
multihost.process_index(),
self._handler,
tmpdir.get(),
)
self._handler.finalize(tmpdir.get())
logging.info(
'[process=%s] Async Save Callback[2/3]: Running'
' post_finalization_callback: %s on %s',
multihost.process_index(),
self._post_finalization_callback,
tmpdir.get_final(),
)
if self._post_finalization_callback is not None:
self._post_finalization_callback()
logging.info(
'[process=%s] Async Save Callback[3/3]: Finalizing checkpoint'
' directory: %s',
multihost.process_index(),
tmpdir.get(),
)
_on_commit_callback(tmpdir, checkpoint_start_time)

self._async_manager.start_async_commit(
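The `_callback` above runs three ordered phases once all commit futures resolve. A condensed sketch of that sequence (the handler, hook, and tmpdir objects are placeholders for the Orbax types, with dependencies passed in explicitly):

```python
def make_finalize_callback(handler, post_finalization_callback, tmpdir,
                           on_commit_callback):
  def _callback() -> None:
    handler.finalize(tmpdir.get())     # 1/3: finalize the handler.
    if post_finalization_callback is not None:
      post_finalization_callback()     # 2/3: run the user hook.
    on_commit_callback()               # 3/3: promote tmp dir to final.
  return _callback
```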
6 changes: 5 additions & 1 deletion checkpoint/orbax/checkpoint/composite_checkpoint_handler.py
@@ -404,7 +404,11 @@ async def async_save(
_maybe_raise_reserved_item_error(item_name)
handler = self._get_or_set_handler(item_name, arg)
if isinstance(handler, AsyncCheckpointHandler):
futures.extend(await handler.async_save(item_directory.get(), args=arg))
commit_futures = await handler.async_save(
item_directory.get(), args=arg
)
if commit_futures is not None:
futures.extend(commit_futures)
else:
handler.save(item_directory.get(), args=arg)
return futures
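The item loop above dispatches per handler type, and async handlers may now legitimately return `None`. A sketch of that dispatch under simplified assumptions (plain types, and `hasattr` standing in for the `AsyncCheckpointHandler` isinstance check):

```python
async def save_items(named_args, get_handler, directory):
  commit_futures = []
  for item_name, arg in named_args.items():
    handler = get_handler(item_name, arg)
    item_dir = directory / item_name
    if hasattr(handler, 'async_save'):  # Simplified isinstance check.
      item_futures = await handler.async_save(item_dir, args=arg)
      if item_futures is not None:      # async_save may return None.
        commit_futures.extend(item_futures)
    else:
      handler.save(item_dir, args=arg)
  return commit_futures
```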
24 changes: 19 additions & 5 deletions checkpoint/orbax/checkpoint/multihost/utils.py
@@ -163,14 +163,19 @@ def get_barrier_sync_fn(

def _fn(*, key: str, timeout_ms: int) -> None:
key = _unique_barrier_key(key)
logging.info('Waiting at barrier: %s', key)
logging.info('[process=%s] Waiting at barrier: %s', process_index(), key)
if processes is None:
client.wait_at_barrier(key, timeout_ms)
else:
logging.debug('Current process: %d', process_index())
logging.debug('Barrier processes: %s', barrier_processes)
logging.debug(
'[process=%s] Barrier processes: %s',
process_index(),
barrier_processes,
)
client.wait_at_barrier(key, timeout_ms, process_ids=barrier_processes)
logging.info('Done waiting at barrier: %s', key)
logging.info(
'[process=%s] Done waiting at barrier: %s', process_index(), key
)

return _fn

@@ -239,7 +244,16 @@ def sync_global_processes(

def reached_preemption(step: int) -> bool:
"""Returns True if a preemption sync point has been reached."""
return multihost_utils.reached_preemption_sync_point(step)
preemption_sync_point_reached = multihost_utils.reached_preemption_sync_point(
step
)
if preemption_sync_point_reached:
logging.warning(
'[process=%s] Reached preemption sync point, step=%s',
process_index(),
step,
)
return preemption_sync_point_reached
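`reached_preemption` now logs a warning when the sync point fires; callers typically use it to checkpoint before eviction. A hedged usage sketch, assuming it sits next to the `reached_preemption` helper above (`train_step` and `save_checkpoint` are placeholders):

```python
def train_loop(state, start_step, num_steps, train_step, save_checkpoint):
  for step in range(start_step, num_steps):
    state = train_step(state)
    if reached_preemption(step):
      # Every process reaches the sync point at the same step, so all
      # hosts write a consistent checkpoint before the job is evicted.
      save_checkpoint(step, state)
      break
  return state
```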


def is_primary_host(primary_host: Optional[int]):
33 changes: 25 additions & 8 deletions checkpoint/orbax/checkpoint/path/atomicity.py
@@ -149,7 +149,7 @@ def _create_tmp_directory(
path_permission_mode: Path permission mode for the temp directory. e.g.
0o750. Please check
https://github.com/google/etils/blob/main/etils/epath/backend.py if your
path is supported.
path is supported.
checkpoint_metadata_store: optional `CheckpointMetadataStore` instance. If
present then it is used to create `CheckpointMetadata` with current
timestamp.
@@ -177,16 +177,23 @@ def _create_tmp_directory(
if tmp_dir.exists():
if step_lib.is_tmp_checkpoint(tmp_dir):
logging.warning(
'Attempted to create temporary directory %s which already exists.'
' Removing existing directory since it is not finalized.',
'[process=%s] Attempted to create temporary directory %s which'
' already exists. Removing existing directory since it is not'
' finalized.',
multihost.process_index(),
tmp_dir,
)
tmp_dir.rmtree()
else:
raise FileExistsError(
f'Attempted to create temporary directory {tmp_dir} which already'
' exists.'
' exists but could not be resolved as a checkpoint tmp directory.'
)
logging.info(
'[process=%s] Creating tmp directory %s',
multihost.process_index(),
tmp_dir,
)
tmp_dir.mkdir(parents=True, exist_ok=False, mode=path_permission_mode)
if checkpoint_metadata_store is not None:
checkpoint_metadata_store.write(
@@ -313,7 +320,6 @@ def create(
Raises:
FileExistsError: if tmp directory already exists.
"""
logging.info('Creating tmp directory %s', self._tmp_path)
mode = step_lib.WORLD_READABLE_MODE # pylint: disable=unused-variable
mode = (
file_options.path_permission_mode or self._path_permission_mode or mode
@@ -333,7 +339,12 @@ def finalize(self):
Updates checkpoint metadata with commit_timestamp_nsecs.
"""
logging.info('Renaming %s to %s', self._tmp_path, self._final_path)
logging.info(
'[process=%s] Renaming %s to %s',
multihost.process_index(),
self._tmp_path,
self._final_path,
)
if self._checkpoint_metadata_store:
self._checkpoint_metadata_store.wait_until_finished()
self._checkpoint_metadata_store.update(
@@ -453,7 +464,9 @@ def finalize(self):
Updates checkpoint metadata with commit_timestamp_nsecs.
"""
logging.info('Finalizing %s', self._tmp_path)
logging.info(
'[process=%s] Finalizing %s', multihost.process_index(), self._tmp_path
)
if self._checkpoint_metadata_store:
self._checkpoint_metadata_store.wait_until_finished()
self._checkpoint_metadata_store.update(
@@ -492,4 +505,8 @@ def on_commit_callback(
"""
tmp_dir.finalize()
step_lib.record_saved_duration(checkpoint_start_time)
logging.info('Finished saving checkpoint to `%s`.', tmp_dir.get_final())
logging.info(
'[process=%s] Committed checkpoint save to `%s`.',
multihost.process_index(),
tmp_dir.get_final(),
)
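All of `atomicity.py`'s logging changes sit on the same write-then-rename commit protocol. A minimal sketch of that idea, independent of Orbax (the paths and `write_fn` hook are illustrative, and the rename is atomic only within one POSIX filesystem):

```python
import os
import pathlib
import tempfile


def atomic_save(final_dir: pathlib.Path, write_fn) -> None:
  # Write everything under a sibling tmp directory first...
  tmp_dir = pathlib.Path(
      tempfile.mkdtemp(prefix=f'{final_dir.name}.orbax-tmp.',
                       dir=final_dir.parent))
  write_fn(tmp_dir)
  # ...then promote it in one step; readers never observe a partial save.
  os.rename(tmp_dir, final_dir)
```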
11 changes: 6 additions & 5 deletions checkpoint/orbax/checkpoint/random_key_checkpoint_handler.py
@@ -130,8 +130,9 @@ async def async_save():
commit_futures = await self.async_save(directory, *args, **kwargs) # pytype: disable=bad-return-type
# Futures are already running, so sequential waiting is equivalent to
# concurrent waiting.
for f in commit_futures:
f.result() # Block on result.
if commit_futures: # May be None.
for f in commit_futures:
f.result() # Block on result.

asyncio.run(async_save())

@@ -158,9 +159,7 @@ def restore(
}),
)

return self.post_restore(
result[self._key_name], result[self._key_metadata]
)
return self.post_restore(result[self._key_name], result[self._key_metadata])

def finalize(self, directory: epath.Path):
self._handler.finalize(directory)
@@ -284,11 +283,13 @@ class NumpyRandomKeySaveArgs(CheckpointArgs):
Attributes:
item (required): a Numpy random key in legacy or nonlegacy format
"""

item: NumpyRandomKeyType


@register_with_handler(NumpyRandomKeyCheckpointHandler, for_restore=True)
@dataclasses.dataclass
class NumpyRandomKeyRestoreArgs(CheckpointArgs):
"""Numpy random key restore args."""

pass
