
Commit 7742139

fixed review comments
1 parent 4dfa903 commit 7742139

File tree

2 files changed: +108 -46 lines changed


keras/src/callbacks/orbax_checkpoint.py

Lines changed: 71 additions & 17 deletions
@@ -73,6 +73,44 @@ class OrbaxCheckpoint(MonitorCallback):
     inference.
     It supports policies for keeping checkpoints and deciding when to save.

+    Example:
+
+    ```python
+    model.compile(loss=..., optimizer=...,
+                  metrics=['accuracy'])
+
+    EPOCHS = 10
+    checkpoint_dir = '/tmp/ckpt'
+    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
+        directory=checkpoint_dir,
+        monitor='val_accuracy',
+        mode='max',
+        save_best_only=True)
+
+    # Model is saved at the end of every epoch, if it's the best seen so far.
+    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
+
+    # The model can be loaded from a specific checkpoint step:
+    checkpoint = keras.callbacks.OrbaxCheckpoint(directory=checkpoint_dir)
+    checkpoint.load_checkpoint(step=5, model=model)  # Load from step 5
+
+    # Alternatively, save checkpoints every N batches:
+    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
+        directory=checkpoint_dir,
+        save_freq=100)  # Save every 100 batches
+
+    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
+
+    # Or use a SaveDecisionPolicy for more control:
+    from orbax.checkpoint import checkpoint_managers
+    policy = checkpoint_managers.FixedIntervalPolicy(interval=5)
+    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
+        directory=checkpoint_dir,
+        save_decision_policy=policy)  # Save every 5 epochs
+
+    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
+    ```
+
     Args:
         directory: string, path to the directory where to save the checkpoints.
         monitor: The metric name to monitor (e.g., 'val_loss').
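The docstring examples load from a fixed step. For completeness, a small helper sketch for restoring the most recent checkpoint instead; `restore_latest` is a hypothetical name, and the `.manager` attribute and `load_checkpoint` method it relies on are the ones the docstring and the new test below already use:

```python
import keras

def restore_latest(checkpoint_dir, model):
    """Hypothetical helper (not part of this commit): restore the newest
    checkpoint, if any, and return its step."""
    checkpoint = keras.callbacks.OrbaxCheckpoint(directory=checkpoint_dir)
    steps = sorted(checkpoint.manager.all_steps())
    if not steps:
        return None  # Nothing has been saved yet.
    checkpoint.load_checkpoint(step=steps[-1], model=model)
    return steps[-1]
```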
@@ -86,7 +124,7 @@ class OrbaxCheckpoint(MonitorCallback):
         keep_period: Integer, keep one checkpoint every `keep_period` saves.
             Useful for keeping checkpoints less frequently over long runs.
         initial_value_threshold: Floating point initial "best" value for the
-            monitor, used with `save_best_only`.
+            monitor, used with `save_best_only`.
         save_optimizer_state: Boolean, whether to include optimizer variables
             in the checkpoint. Defaults to True.
         save_on_background: Boolean, whether to save asynchronously in the
@@ -110,8 +148,9 @@ class OrbaxCheckpoint(MonitorCallback):
             during saving. Keys should match composite_state keys (e.g.,
             'model_weights', 'optimizer_state'). Defaults to None.
         save_decision_policy: orbax.checkpoint.SaveDecisionPolicy object to
-            control when checkpoints are saved. If provided, overrides the
-            default save frequency logic. Defaults to None.
+            control when checkpoints are saved. Currently supports
+            FixedIntervalPolicy for saving at regular intervals. If provided,
+            overrides the default save frequency logic. Defaults to None.
         save_interval: Integer, save checkpoints every N steps. If provided,
             overrides save_freq. Defaults to None.
     """
@@ -166,6 +205,7 @@ def __init__(
         self._batches_seen_since_last_saving = 0
         self._last_batch_seen = 0
         self._current_epoch = 0  # Keep track of epoch
+        self._total_batches_seen = 0  # Global batch counter for step tracking

         if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
             raise ValueError("Unrecognized save_freq")
@@ -174,10 +214,10 @@ def __init__(
         # if provided
         should_save_fn = None
         if save_decision_policy is not None:
-            # For now, create a simple should_save_fn that saves every 2 steps
-            # This is a placeholder - proper integration would require
-            # PolicyCheckpointInfo
-            should_save_fn = lambda step, prev_step=None: step % 2 == 0
+            # When using save_decision_policy, let Orbax handle
+            # should_save_fn internally
+            # Don't override should_save_fn
+            pass
         elif save_interval is not None:
             # Create should_save_fn that saves every N steps
             should_save_fn = (
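The `save_interval` branch above is cut off mid-expression by the diff context. Purely as an illustration of the behavior its comment names (save every N steps), a self-contained sketch; the lambda body is an assumption, not code from the commit:

```python
save_interval = 5  # Stand-in for the __init__ argument.

# Assumed shape of the truncated expression; the real body continues
# past the diff context shown above.
should_save_fn = lambda step, prev_step=None: step % save_interval == 0

# Steps 0, 5, 10 out of 0..11 would trigger a save.
assert [s for s in range(12) if should_save_fn(s)] == [0, 5, 10]
```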
@@ -199,6 +239,7 @@ def __init__(
             enable_background_delete=self.enable_background_delete,
             async_options=async_options,
             should_save_fn=should_save_fn,
+            save_decision_policy=save_decision_policy,
         )
         # Ensure directory exists (only needed on one process in multi-host)
         if backend.get_process_index() == 0:
@@ -218,7 +259,14 @@ def _should_save_on_batch(self, batch):
         if self.save_freq == "epoch":
             return False

-        self._batches_seen_since_last_saving += 1
+        if batch <= self._last_batch_seen:  # New epoch.
+            add_batches = batch + 1
+        else:
+            add_batches = batch - self._last_batch_seen
+        self._batches_seen_since_last_saving += add_batches
+        self._last_batch_seen = batch
+        self._total_batches_seen += add_batches
+
         if self._batches_seen_since_last_saving >= self.save_freq:
             self._batches_seen_since_last_saving = 0
             return True
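The new accounting handles batch indices that restart at 0 each epoch: a restart is detected by `batch <= self._last_batch_seen`, and `batch + 1` counts batches 0 through `batch` inclusive. A standalone sketch of the same arithmetic:

```python
# Illustration of the counting logic above, outside the callback.
last_batch_seen = 0
total_batches_seen = 0

def count(batch):
    global last_batch_seen, total_batches_seen
    if batch <= last_batch_seen:  # New epoch: batch indices restarted.
        add_batches = batch + 1   # Batches 0..batch, inclusive.
    else:
        add_batches = batch - last_batch_seen
    last_batch_seen = batch
    total_batches_seen += add_batches
    return total_batches_seen

# Two epochs of three batches each give a monotonically increasing total.
assert [count(b) for b in [0, 1, 2, 0, 1, 2]] == [1, 2, 3, 4, 5, 6]
```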
@@ -235,8 +283,8 @@ def _get_current_step(self):
                 backend.convert_to_numpy(self.model.optimizer.iterations)
             )
         else:
-            # Fallback: use batch count
-            return self._last_batch_seen
+            # Fallback: use global batch count
+            return self._total_batches_seen

     def _save_checkpoint(self, step, logs=None):
         """Save a checkpoint at the given step."""
@@ -333,8 +381,6 @@ def on_train_batch_end(self, batch, logs=None):
             # step
             step = self._get_current_step()
             self._save_checkpoint(step=step, logs=logs)
-            # Ensure all processes sync after save operation
-            self.manager.wait_until_finished()

     def on_epoch_end(self, epoch, logs=None):
         self._current_epoch = epoch
@@ -343,9 +389,19 @@ def on_epoch_end(self, epoch, logs=None):

         should_save = False
         if self.save_decision_policy is not None:
-            # For FixedIntervalPolicy, save every N steps
-            # This is a simplified implementation
-            should_save = epoch % 2 == 0  # Save every 2 epochs for the test
+            # Handle FixedIntervalPolicy by extracting its interval
+            from orbax.checkpoint import checkpoint_managers
+
+            if isinstance(
+                self.save_decision_policy,
+                checkpoint_managers.FixedIntervalPolicy,
+            ):
+                should_save = epoch % self.save_decision_policy.interval == 0
+            else:
+                # For other policies, fall back to saving every epoch
+                # TODO: Implement full support for other SaveDecisionPolicy
+                # types
+                should_save = True
         elif self.save_interval is not None:
             # Save every N epochs
             should_save = epoch % self.save_interval == 0
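A quick check of the interval arithmetic: with `interval=3` over epochs 0 through 9, saves land on epochs 0, 3, 6, and 9, exactly what the new test at the bottom of this commit asserts:

```python
interval = 3  # Mirrors FixedIntervalPolicy(interval=3) in the test below.
saved_epochs = [epoch for epoch in range(10) if epoch % interval == 0]
assert saved_epochs == [0, 3, 6, 9]
```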
@@ -371,8 +427,6 @@ def on_epoch_end(self, epoch, logs=None):
         if should_save:
             # Use epoch number as the step for Orbax save
             self._save_checkpoint(step=epoch, logs=logs)
-            # Ensure all processes sync after save operation
-            self.manager.wait_until_finished()

     def on_train_end(self, logs=None):
         if self.verbose > 0:

keras/src/callbacks/orbax_checkpoint_test.py

Lines changed: 37 additions & 29 deletions
@@ -643,6 +643,9 @@ def test_checkpoint_transformations(self):

         checkpoint_dir = os.path.join(self.temp_dir, "test_transforms")

+        # Train for one step first to initialize optimizer variables
+        model.fit(x, y, epochs=1, verbose=0)
+
         # Create save_args that converts float32 to float16
         # Note: save_args structure must match composite_state structure (lists)
         save_args = {
@@ -652,18 +655,7 @@ def test_checkpoint_transformations(self):
                 SaveArgs(dtype=np.dtype(np.float16)),  # output weights
                 SaveArgs(dtype=np.dtype(np.float16)),  # output bias
             ],
-            "optimizer_state": [
-                None,  # iteration count (no change)
-                None,  # learning rate (no change)
-                None,  # momentum vars (no change)
-                None,  # momentum vars (no change)
-                None,  # momentum vars (no change)
-                None,  # momentum vars (no change)
-                None,  # momentum vars (no change)
-                None,  # momentum vars (no change)
-                None,  # momentum vars (no change)
-                None,  # momentum vars (no change)
-            ],
+            "optimizer_state": [None] * len(model.optimizer.variables),
         }

         callback = OrbaxCheckpoint(
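The `[None] * len(model.optimizer.variables)` form only lines up with the optimizer state if the slot variables already exist, which is why the earlier hunk adds a one-epoch `fit()` before `save_args` is built. A sketch of that ordering, with a hypothetical model standing in for the test's `_create_test_model()` helper:

```python
import numpy as np
import keras

# Hypothetical model; any compiled model illustrates the point.
model = keras.Sequential(
    [keras.Input(shape=(8,)), keras.layers.Dense(4), keras.layers.Dense(1)]
)
model.compile(optimizer=keras.optimizers.SGD(momentum=0.9), loss="mse")

# Optimizer slot variables are created lazily on the first train step.
model.fit(np.zeros((2, 8)), np.zeros((2, 1)), epochs=1, verbose=0)

# Only now does the variable list have its full length, so the save_args
# entry matches the optimizer state one-to-one.
assert len(model.optimizer.variables) > 2  # iterations, lr, momentum slots
optimizer_save_args = [None] * len(model.optimizer.variables)
```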
@@ -672,11 +664,11 @@ def test_checkpoint_transformations(self):
             save_transforms=save_args,
         )

-        # Train for a few epochs
-        model.fit(x, y, epochs=2, callbacks=[callback], verbose=0)
+        # Train for one more epoch to trigger the save
+        model.fit(x, y, epochs=1, callbacks=[callback], verbose=0)

         # Load checkpoint data to verify transformation was applied
-        checkpoint_data = self._load_checkpoint_data(callback, step=1)
+        checkpoint_data = self._load_checkpoint_data(callback, step=0)

         # Check that model weights were saved in float16
         saved_weights = checkpoint_data["model_weights"]
@@ -1503,21 +1495,37 @@ def _load_checkpoint_data_from_manager(self, manager, step):
         except Exception as e:
             self.fail(f"Failed to load checkpoint data: {e}")

-    def _get_state_as_numpy_helper(self, model):
-        """Helper to convert model state to numpy (copied from
-        orbax_checkpoint.py)."""
-        try:
-            import keras
+    @pytest.mark.requires_trainable_backend
+    def test_save_decision_policy_integration(self):
+        """Test using orbax.checkpoint.SaveDecisionPolicy objects."""
+        from orbax.checkpoint import checkpoint_managers

-            model_weights_np = [
-                keras.ops.convert_to_numpy(w) for w in model.weights
-            ]
-            optimizer_vars_np = [
-                keras.ops.convert_to_numpy(v) for v in model.optimizer.variables
-            ]
-            return model_weights_np, optimizer_vars_np
-        except Exception:
-            return None, None
+        model = self._create_test_model()
+        x, y = self._create_dummy_data()
+
+        checkpoint_dir = os.path.join(self.temp_dir, "test_decision_policy")
+
+        # Use FixedIntervalPolicy to save every 3 steps
+        policy = checkpoint_managers.FixedIntervalPolicy(
+            interval=3,  # Save every 3 steps
+        )
+
+        callback = OrbaxCheckpoint(
+            directory=checkpoint_dir,
+            save_decision_policy=policy,
+        )
+
+        # Train for 10 epochs (steps 0-9)
+        model.fit(x, y, epochs=10, callbacks=[callback], verbose=0)
+
+        # Should have saved at steps 0, 3, 6, 9
+        all_steps = sorted(callback.manager.all_steps())
+        expected_steps = [0, 3, 6, 9]
+        self.assertEqual(
+            all_steps,
+            expected_steps,
+            f"Should save at steps {expected_steps}, got {all_steps}",
+        )

     def _load_checkpoint_data(self, callback, step):
         """Helper method to load raw checkpoint data for testing."""
