
Conversation

@amitsrivastava78 amitsrivastava78 (Collaborator) commented Oct 22, 2025

Supports the following features:

  • Asynchronous Checkpointing
  • Composite Checkpointing
  • Preservation Policies
  • Save Decision Policies
  • Transformations
  • Custom Handlers

@gemini-code-assist (Contributor)

Summary of Changes

Hello @amitsrivastava78, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates Orbax checkpointing into Keras 3.0, providing a robust and flexible mechanism for saving and restoring training progress. The new OrbaxCheckpoint callback offers features like asynchronous saving, customizable save policies, and the ability to save complex states including model weights, optimizer variables, metrics, and data iterator positions. This enhancement significantly improves the reliability and efficiency of training large models, especially in distributed environments, by leveraging Orbax's advanced capabilities.

Highlights

  • New OrbaxCheckpoint Callback: Introduces a new OrbaxCheckpoint callback for Keras 3.0, enabling advanced data-centric saving and restoration of model states.
  • Asynchronous Checkpointing: Supports asynchronous saving of model weights and optimizer states, allowing training to continue without I/O blocking.
  • Comprehensive Checkpointing Features: Includes support for composite checkpointing, preservation policies (e.g., max_to_keep, keep_period), save decision policies (e.g., save_interval), and custom transformations during saving.
  • Distributed Training Support: Adds a get_process_index utility function to the Keras backend, facilitating distributed training setups by identifying the primary process for checkpoint operations across JAX, TensorFlow, and PyTorch.
  • Extensible with Custom Handlers: Exposes advanced Orbax functionalities like CheckpointManager, TypeHandler, and register_type_handler to allow users to define custom serialization logic for complex objects.
  • Iterator State Saving and Restoration: Enables saving and restoring the state of data iterators, crucial for seamless training resumption from a specific point, with backend-specific examples for TensorFlow, JAX, and PyTorch.
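
To make these highlights concrete, here is a minimal usage sketch. It relies only on constructor arguments visible in this PR's diff (directory, save_freq, max_to_keep, save_on_background) and is illustrative, not taken from the PR's own documentation:

import numpy as np
import keras

# Small model to checkpoint.
model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

# Argument names mirror the __init__ signature shown later in this review;
# defaults and exact behavior may differ in the final API.
checkpoint_cb = keras.callbacks.OrbaxCheckpoint(
    directory="/tmp/orbax_ckpts",
    save_freq="epoch",        # save at the end of every epoch
    max_to_keep=3,            # preservation policy: keep only the 3 newest checkpoints
    save_on_background=True,  # asynchronous saving; training is not blocked on I/O
)

x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")
model.fit(x, y, epochs=5, callbacks=[checkpoint_cb])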


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces OrbaxCheckpoint, a new Keras callback for advanced checkpointing using the Orbax library. This is a significant feature addition that enables asynchronous saving, composite checkpoints, and other powerful capabilities. The implementation is extensive and is supported by a comprehensive suite of tests.

My review has identified several important issues that need attention. There are critical correctness and performance bugs in the main implementation: the batch-based saving logic is flawed, and the asynchronous saving feature is effectively disabled by blocking calls. Additionally, some features are incomplete, and there are minor areas for improvement in the tests to enhance maintainability. I have provided specific suggestions to address these points. After these fixes, this will be a very valuable addition to Keras.
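
For reference, the non-blocking pattern that asynchronous checkpointing is meant to enable looks roughly like this with a recent orbax-checkpoint release (a sketch using plain Orbax, not the PR's callback code):

import numpy as np
import orbax.checkpoint as ocp

options = ocp.CheckpointManagerOptions(
    max_to_keep=3,
    enable_async_checkpointing=True,
)
manager = ocp.CheckpointManager("/tmp/orbax_async_demo", options=options)

for step in range(5):
    # Toy "model state"; in the callback this would be the model's state tree.
    state = {"kernel": np.full((4, 8), float(step)), "bias": np.zeros((8,))}
    manager.save(step, args=ocp.args.StandardSave(state))  # returns without blocking

# Block exactly once, at the end of training, rather than after every save.
manager.wait_until_finished()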

Comment on lines 119 to 141
def __init__(
    self,
    directory,
    monitor="val_loss",
    verbose=0,
    save_best_only=False,
    mode="auto",
    save_freq="epoch",
    max_to_keep=5,
    keep_period=None,
    initial_value_threshold=None,
    save_optimizer_state=True,
    save_on_background=True,
    save_metadata=None,
    save_data_iterator=None,
    save_metrics_state=False,
    async_timeout_secs=600,
    enable_background_delete=False,
    post_finalization_callback=None,
    save_transforms=None,
    save_decision_policy=None,
    save_interval=None,
):


Severity: medium

The __init__ method has 20 arguments, which is quite high. The Keras API design guidelines suggest reconsidering signatures with more than 6-7 arguments [1]. While I understand the need to expose Orbax's functionality, it might be worth exploring if some of these could be grouped into a configuration object to improve readability and usability, similar to how ocp.CheckpointManagerOptions is used internally.
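
For illustration, such a grouping could look roughly like the following; the CheckpointOptions name and its field selection are hypothetical, not part of this PR:

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class CheckpointOptions:
    """Hypothetical container bundling the less common constructor arguments."""

    max_to_keep: int = 5
    keep_period: Optional[int] = None
    save_on_background: bool = True
    async_timeout_secs: int = 600
    enable_background_delete: bool = False
    save_transforms: Optional[dict] = None
    save_decision_policy: Optional[Any] = None
    save_interval: Optional[int] = None


# The callback signature could then shrink to the familiar Keras-style core:
#   OrbaxCheckpoint(directory, monitor="val_loss", save_best_only=False,
#                   save_freq="epoch", options=CheckpointOptions(max_to_keep=3))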

Style Guide References

Footnotes

  1. The style guide recommends that functions with more than 6-7 arguments should be re-evaluated for simplification, possibly by breaking them into smaller objects or modular pieces.

@codecov-commenter codecov-commenter commented Oct 22, 2025

Codecov Report

❌ Patch coverage is 79.36508% with 26 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.65%. Comparing base (47fcb39) to head (4d659f4).
⚠️ Report is 39 commits behind head on master.

Files with missing lines                           Patch %   Lines
keras/src/callbacks/orbax_checkpoint.py            78.81%    16 Missing and 9 partials ⚠️
keras/api/_tf_keras/keras/callbacks/__init__.py     0.00%    1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21762      +/-   ##
==========================================
- Coverage   82.69%   82.65%   -0.04%     
==========================================
  Files         573      578       +5     
  Lines       58888    59631     +743     
  Branches     9218     9356     +138     
==========================================
+ Hits        48696    49290     +594     
- Misses       7845     7928      +83     
- Partials     2347     2413      +66     
Flag               Coverage Δ
keras              82.48% <78.57%> (-0.02%) ⬇️
keras-jax          63.33% <75.39%> (+0.09%) ⬆️
keras-numpy        57.49% <30.95%> (-0.23%) ⬇️
keras-openvino     34.35% <30.95%> (-0.05%) ⬇️
keras-tensorflow   64.14% <73.80%> (+0.12%) ⬆️
keras-torch        63.63% <73.80%> (+0.06%) ⬆️

Flags with carried forward coverage won't be shown.



@hertschuh hertschuh left a comment


Thanks for the PR. This checkpointing system has a ton of features!

Quick first pass.


@hertschuh hertschuh left a comment


A couple more comments I forgot.

- Remove conditional export decorator to ensure OrbaxCheckpoint is always available
- Remove unnecessary exception handling in state tree operations
- Update process index check comment for clarity
- Format code to comply with 80-character line limit
- Add distribution_lib modules for backend-specific distributed training support
- Remove unused 'result' variable in _reconstruct_state_tree_with_values
- Fix long comment line in test file
- Apply code formatting changes
…st handling

- Implement OrbaxCheckpoint callback for async checkpointing with state tree handling
- Add conditional exports for optional orbax-checkpoint dependency
- Use pytest.importorskip for clean optional dependency testing
- Ensure graceful handling when orbax-checkpoint is not installed

@hertschuh hertschuh left a comment


The JAX implementation of def process_id() is missing.

General questions:

  • Does this as-is support all backends?
  • Does this support JAX sharding? I don't see anything related to sharding (which may be normal). What about re-sharding?
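
For reference on the first question, a backend-dispatching sketch of process_id() could look like the following; the dispatch shape and the TensorFlow/PyTorch fallbacks are assumptions (the PR itself puts this logic in per-backend distribution_lib modules):

def process_id():
    """Return the index of the current process in a multi-process job (sketch)."""
    import keras

    backend = keras.backend.backend()
    if backend == "jax":
        import jax

        return jax.process_index()  # 0 on single-host setups
    if backend == "tensorflow":
        import tensorflow as tf

        # Under multi-worker strategies the task id comes from TF_CONFIG;
        # default to 0 when no cluster is configured.
        resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
        return resolver.task_id or 0
    if backend == "torch":
        import torch.distributed as dist

        return dist.get_rank() if dist.is_initialized() else 0
    return 0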

- Preserve nested state tree structures instead of flattening for better layer name preservation
- Add backward compatibility for old flattened format checkpoints
- Simplify test class by using self.get_temp_dir() instead of setUp/tearDown
- Remove silent pytest.importorskip, add explicit skip conditions for backend-specific tests
- Move process_id function from backend to distribution module
- Update imports to use centralized LazyModule for orbax.checkpoint
- Test across all backends (JAX, TensorFlow, PyTorch) - all passing
@amitsrivastava78 amitsrivastava78 force-pushed the orbax-checkpoint-test-improvements branch from 621f566 to eb7855d on November 10, 2025 09:45
…s expected failures

Neural networks are inherently non-deterministic, so pipeline consistency
checks should be skipped rather than fail. Added check_pipeline_consistency
to EXPECTED_FAILED_CHECKS for all sklearn wrapper types.
- Avoid unnecessary numpy conversion in _get_state_tree() for JAX backend
- Preserve JAX arrays during saving instead of converting to numpy
- Maintain cross-backend compatibility with proper loading conversions
- Update async waiting to use CheckpointManager.wait_until_finished()
- Implement AlwaysSavePolicy for reliable save decisions
- Add expected failures for sklearn tests due to neural network non-determinism
@amitsrivastava78 amitsrivastava78 force-pushed the orbax-checkpoint-test-improvements branch from c14c30e to b7a0dff on November 11, 2025 05:54
- Preserve JAX arrays during saving when jax.monitoring.record_scalar is available
- Fall back to numpy conversion for older JAX versions that don't have record_scalar
- Maintain cross-backend compatibility while avoiding unnecessary conversions
- Update async waiting to use CheckpointManager.wait_until_finished()
- Implement AlwaysSavePolicy for reliable save decisions
- Add expected failures for sklearn tests due to neural network non-determinism
        except Exception:
            pass  # Ignore errors during cleanup

    def load_checkpoint(self, step, model=None):
Collaborator


Hmm... I'll have to think about where load_checkpoint, load_latest, all_steps should go. This is not how the Keras loading APIs have worked so far.

Collaborator Author


OK, one option would be for the OrbaxCheckpoint API to follow Keras' traditional separation of concerns, moving the loading functionality from the callback class into standalone functions in the saving API:
@keras_export("keras.saving.load_orbax_checkpoint")
def load_orbax_checkpoint(directory, step, model=None)

@keras_export("keras.saving.load_latest_orbax_checkpoint")
def load_latest_orbax_checkpoint(directory, model=None)

@keras_export("keras.saving.list_orbax_checkpoint_steps")
def list_orbax_checkpoint_steps(directory)

load_checkpoint(step, model=None) → move to keras.saving.load_orbax_checkpoint()
load_latest(model=None) → move to keras.saving.load_latest_orbax_checkpoint()
all_steps() → move to keras.saving.list_orbax_checkpoint_steps()

Let me know what you think
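
Under that proposal, resuming from a checkpoint could look roughly like this (these functions are only the proposal above, they do not exist in Keras today, and build_model() is a user-defined stand-in):

import keras

directory = "/tmp/orbax_ckpts"

# List available checkpoint steps, then restore the latest one into a model
# whose architecture matches the saved state.
steps = keras.saving.list_orbax_checkpoint_steps(directory)  # e.g. [0, 1, 2]
model = build_model()  # user-defined model with matching architecture
keras.saving.load_latest_orbax_checkpoint(directory, model=model)

# Or restore a specific step.
keras.saving.load_orbax_checkpoint(directory, step=steps[0], model=model)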

Collaborator Author


Also, since the features were removed as you suggested, these functions are no longer needed, so they will not be present in the latest commit. Let me know if you are OK with this.

Comment on lines +13 to +43
# Import advanced Orbax functionality directly from the LazyModule
# These will only be available if orbax-checkpoint is installed
if ocp.available:
    Checkpointer = ocp.training.Checkpointer
    save_pytree = ocp.save_pytree
    load_pytree = ocp.load_pytree
    preservation_policies = ocp.training.preservation_policies
    save_decision_policies = ocp.training.save_decision_policies
    _orbax_available = True
else:
    Checkpointer = None
    save_pytree = None
    load_pytree = None
    preservation_policies = None
    save_decision_policies = None
    _orbax_available = False

# Import our OrbaxCheckpoint callback
try:
    from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint

    _orbax_available = _orbax_available and True
except ImportError:
    OrbaxCheckpoint = None
    _orbax_available = False


@pytest.mark.skipif(
    not _orbax_available,
    reason="OrbaxCheckpoint requires the 'orbax-checkpoint' package",
)
Collaborator


Remove all the ifs and try / except and @pytest.mark.skipif. The unit test has to fail if Orbax is not installed.

Collaborator Author


That would cause CI failures on the GPU tests, where orbax is not installed.

- Optimize JAX array handling: avoid unnecessary numpy conversions for JAX >= 0.7.0
- Simplify step counting: use _total_batches_seen directly instead of dual mechanisms
- Remove impossible error checks and verbose messages
- Clean up unused Orbax exports that violated import policies
- Update error message for consistency
- All changes maintain backward compatibility and pass tests across JAX/TensorFlow/PyTorch backends
- Remove extra features: save_metadata, save_data_iterator, post_finalization_callback, save_decision_policy, keep_period
- Remove loading methods: load_checkpoint, load_latest, all_steps, _restore_model_state_from_full_tree
- Replace save_optimizer_state/save_metrics_state with save_weights_only parameter
- Add comprehensive test coverage for all remaining functionality
- Maintain async saving and preservation policies as Orbax-specific advantages
- All tests pass across JAX/TensorFlow/PyTorch backends