@@ -73,6 +73,44 @@ class OrbaxCheckpoint(MonitorCallback):
7373 inference.
7474 It supports policies for keeping checkpoints and deciding when to save.
7575
76+ Example:
77+
78+ ```python
79+ model.compile(loss=..., optimizer=...,
80+ metrics=['accuracy'])
81+
82+ EPOCHS = 10
83+ checkpoint_dir = '/tmp/ckpt'
84+ orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
85+ directory=checkpoint_dir,
86+ monitor='val_accuracy',
87+ mode='max',
88+ save_best_only=True)
89+
90+ # Model is saved at the end of every epoch, if it's the best seen so far.
91+ model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
92+
93+ # The model can be loaded from a specific checkpoint step as follows:
94+ checkpoint = keras.callbacks.OrbaxCheckpoint(directory=checkpoint_dir)
95+ checkpoint.load_checkpoint(step=5, model=model) # Load from step 5
96+
97+ # Alternatively, save checkpoints every N batches:
98+ orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
99+ directory=checkpoint_dir,
100+ save_freq=100) # Save every 100 batches
101+
102+ model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
103+
104+ # Or use a SaveDecisionPolicy for finer-grained control:
105+ from orbax.checkpoint import checkpoint_managers
106+ policy = checkpoint_managers.FixedIntervalPolicy(interval=5)
107+ orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
108+ directory=checkpoint_dir,
109+ save_decision_policy=policy) # Save every 5 epochs
110+
111+ model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
112+ ```
113+
76114 Args:
77115 directory: string, path to the directory where to save the checkpoints.
78116 monitor: The metric name to monitor (e.g., 'val_loss').
@@ -86,7 +124,7 @@ class OrbaxCheckpoint(MonitorCallback):
86124 keep_period: Integer, keep one checkpoint every `keep_period` saves.
87125 Useful for keeping checkpoints less frequently over long runs.
88126 initial_value_threshold: Floating point initial "best" value for the
89- monitor, used with `save_best_only`.
127+ monitor, used with `save_best_only`.
90128 save_optimizer_state: Boolean, whether to include optimizer variables
91129 in the checkpoint. Defaults to True.
92130 save_on_background: Boolean, whether to save asynchronously in the
@@ -110,8 +148,9 @@ class OrbaxCheckpoint(MonitorCallback):
110148 during saving. Keys should match composite_state keys (e.g.,
111149 'model_weights', 'optimizer_state'). Defaults to None.
112150 save_decision_policy: orbax.checkpoint.SaveDecisionPolicy object to
113- control when checkpoints are saved. If provided, overrides the
114- default save frequency logic. Defaults to None.
151+ control when checkpoints are saved. Currently supports
152+ FixedIntervalPolicy for saving at regular intervals. If provided,
153+ overrides the default save frequency logic. Defaults to None.
115154 save_interval: Integer, save checkpoints every N steps. If provided,
116155 overrides save_freq. Defaults to None.
117156 """
@@ -166,6 +205,7 @@ def __init__(
166205 self ._batches_seen_since_last_saving = 0
167206 self ._last_batch_seen = 0
168207 self ._current_epoch = 0 # Keep track of epoch
208+ self ._total_batches_seen = 0 # Global batch counter for step tracking
169209
170210 if self .save_freq != "epoch" and not isinstance (self .save_freq , int ):
171211 raise ValueError ("Unrecognized save_freq" )
@@ -174,10 +214,10 @@ def __init__(
174214 # if provided
175215 should_save_fn = None
176216 if save_decision_policy is not None :
177- # For now, create a simple should_save_fn that saves every 2 steps
178- # This is a placeholder - proper integration would require
179- # PolicyCheckpointInfo
180- should_save_fn = lambda step , prev_step = None : step % 2 == 0
217+ # When using save_decision_policy, let Orbax handle
218+ # should_save_fn internally
219+ # Don't override should_save_fn
220+ pass
181221 elif save_interval is not None :
182222 # Create should_save_fn that saves every N steps
183223 should_save_fn = (
@@ -199,6 +239,7 @@ def __init__(
199239 enable_background_delete = self .enable_background_delete ,
200240 async_options = async_options ,
201241 should_save_fn = should_save_fn ,
242+ save_decision_policy = save_decision_policy ,
202243 )
203244 # Ensure directory exists (only needed on one process in multi-host)
204245 if backend .get_process_index () == 0 :
@@ -218,7 +259,14 @@ def _should_save_on_batch(self, batch):
218259 if self .save_freq == "epoch" :
219260 return False
220261
221- self ._batches_seen_since_last_saving += 1
262+ if batch <= self ._last_batch_seen : # New epoch.
263+ add_batches = batch + 1
264+ else :
265+ add_batches = batch - self ._last_batch_seen
266+ self ._batches_seen_since_last_saving += add_batches
267+ self ._last_batch_seen = batch
268+ self ._total_batches_seen += add_batches
269+
222270 if self ._batches_seen_since_last_saving >= self .save_freq :
223271 self ._batches_seen_since_last_saving = 0
224272 return True
@@ -235,8 +283,8 @@ def _get_current_step(self):
235283 backend .convert_to_numpy (self .model .optimizer .iterations )
236284 )
237285 else :
238- # Fallback: use batch count
239- return self ._last_batch_seen
286+ # Fallback: use global batch count
287+ return self ._total_batches_seen
240288
241289 def _save_checkpoint (self , step , logs = None ):
242290 """Save a checkpoint at the given step."""
@@ -333,8 +381,6 @@ def on_train_batch_end(self, batch, logs=None):
333381 # step
334382 step = self ._get_current_step ()
335383 self ._save_checkpoint (step = step , logs = logs )
336- # Ensure all processes sync after save operation
337- self .manager .wait_until_finished ()
338384
339385 def on_epoch_end (self , epoch , logs = None ):
340386 self ._current_epoch = epoch
@@ -343,9 +389,19 @@ def on_epoch_end(self, epoch, logs=None):
343389
344390 should_save = False
345391 if self .save_decision_policy is not None :
346- # For FixedIntervalPolicy, save every N steps
347- # This is a simplified implementation
348- should_save = epoch % 2 == 0 # Save every 2 epochs for the test
392+ # Handle FixedIntervalPolicy by extracting its interval
393+ from orbax .checkpoint import checkpoint_managers
394+
395+ if isinstance (
396+ self .save_decision_policy ,
397+ checkpoint_managers .FixedIntervalPolicy ,
398+ ):
399+ should_save = epoch % self .save_decision_policy .interval == 0
400+ else :
401+ # For other policies, fall back to saving every epoch
402+ # TODO: Implement full support for other SaveDecisionPolicy
403+ # types
404+ should_save = True
349405 elif self .save_interval is not None :
350406 # Save every N epochs
351407 should_save = epoch % self .save_interval == 0
@@ -371,8 +427,6 @@ def on_epoch_end(self, epoch, logs=None):
371427 if should_save :
372428 # Use epoch number as the step for Orbax save
373429 self ._save_checkpoint (step = epoch , logs = logs )
374- # Ensure all processes sync after save operation
375- self .manager .wait_until_finished ()
376430
377431 def on_train_end (self , logs = None ):
378432 if self .verbose > 0 :
0 commit comments