Skip to content

Commit 9417027

Browse files
Fixed sav decision between keras and orbax
1 parent cd881dd commit 9417027

File tree

2 files changed

+16
-21
lines changed

2 files changed

+16
-21
lines changed

keras/src/callbacks/orbax_checkpoint.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -154,18 +154,18 @@ def __init__(
154154

155155
# Set up save_decision_policy if not provided
156156
if save_decision_policy is None:
157-
if save_freq == "epoch":
158-
# For epoch-based saving, save every epoch
159-
save_decision_policy = (
160-
ocp.training.save_decision_policies.FixedIntervalPolicy(1)
161-
)
162-
else:
163-
# For batch-based saving, save every save_freq batches
164-
save_decision_policy = (
165-
ocp.training.save_decision_policies.FixedIntervalPolicy(
166-
save_freq
167-
)
168-
)
157+
# Let Keras handle all save decisions - configure Checkpointer
158+
# to save unconditionally when save_pytree/save_pytree_async
159+
# is called
160+
class _AlwaysSavePolicy(
161+
ocp.training.save_decision_policies.SaveDecisionPolicy
162+
):
163+
def should_save(
164+
self, current_step_info, previous_steps=None, context=None
165+
):
166+
return True
167+
168+
save_decision_policy = _AlwaysSavePolicy()
169169

170170
# --- Orbax Checkpointer Setup (V1 API) ---
171171
# Map V0 options to V1 parameters
@@ -281,7 +281,8 @@ def _save_checkpoint(self, step, logs=None):
281281

282282
# --- Save Logic (V1 API) ---
283283
# All processes participate in distributed checkpointing
284-
# No wait loop needed. The Checkpointer handles overlapping saves.
284+
# Checkpointer is configured to save unconditionally when
285+
# save_pytree is called
285286
if self.verbose > 0:
286287
print_msg(
287288
f"OrbaxCheckpoint: Triggering async save for step {step}..."
@@ -360,8 +361,8 @@ def on_epoch_end(self, epoch, logs=None):
360361

361362
if should_save:
362363
# Use epoch number as the step for Orbax save
363-
# The Checkpointer will decide if it *actually* saves
364-
# based on its internal SaveDecisionPolicy.
364+
# Keras has already made the save decision - Checkpointer will
365+
# save unconditionally
365366
self._save_checkpoint(step=epoch, logs=logs)
366367

367368
def on_train_end(self, logs=None):

keras/src/wrappers/sklearn_test.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,22 +107,16 @@ def use_floatx(x):
107107
),
108108
"check_supervised_y_2d": "This test assumes reproducibility in fit.",
109109
"check_fit_idempotent": "This test assumes reproducibility in fit.",
110-
"check_pipeline_consistency": "Neural networks are non-deterministic",
111110
},
112111
"SKLearnRegressor": {
113112
"check_parameters_default_constructible": (
114113
"not an issue in sklearn>=1.6"
115114
),
116-
"check_pipeline_consistency": "Neural networks are non-deterministic",
117115
},
118116
"SKLearnTransformer": {
119117
"check_parameters_default_constructible": (
120118
"not an issue in sklearn>=1.6"
121119
),
122-
"check_pipeline_consistency": "Neural networks are non-deterministic",
123-
"check_transformer_data_not_an_array": "Neural networks are "
124-
"non-deterministic",
125-
"check_transformer_general": "Neural networks are non-deterministic",
126120
},
127121
}
128122

0 commit comments

Comments
 (0)