Skip to content

Commit 61bd5e6

Browse files
Fix code formatting and remove unused variable

- Remove unused `result` variable in `_reconstruct_state_tree_with_values`
- Fix long comment line in test file
- Apply code formatting changes
1 parent 822396f commit 61bd5e6

File tree

10 files changed

+34
-36
lines changed

10 files changed

+34
-36
lines changed

keras/src/backend/jax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
from keras.src.backend.jax.core import shape
2525
from keras.src.backend.jax.core import stop_gradient
2626
from keras.src.backend.jax.core import vectorized_map
27+
from keras.src.backend.jax.distribution_lib import process_id
2728
from keras.src.backend.jax.rnn import cudnn_ok
2829
from keras.src.backend.jax.rnn import gru
2930
from keras.src.backend.jax.rnn import lstm
3031
from keras.src.backend.jax.rnn import rnn
31-
from keras.src.backend.jax.distribution_lib import process_id

keras/src/backend/numpy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
from keras.src.backend.numpy.core import random_seed_dtype
2121
from keras.src.backend.numpy.core import shape
2222
from keras.src.backend.numpy.core import vectorized_map
23+
from keras.src.backend.numpy.distribution_lib import process_id
2324
from keras.src.backend.numpy.rnn import cudnn_ok
2425
from keras.src.backend.numpy.rnn import gru
2526
from keras.src.backend.numpy.rnn import lstm
2627
from keras.src.backend.numpy.rnn import rnn
27-
from keras.src.backend.numpy.distribution_lib import process_id

keras/src/backend/numpy/distribution_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33

44
def process_id():
55
"""Return the current process ID for the distribution setting."""
6-
return 0
6+
return 0

keras/src/backend/openvino/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
from keras.src.backend.openvino.core import random_seed_dtype
2121
from keras.src.backend.openvino.core import shape
2222
from keras.src.backend.openvino.core import vectorized_map
23+
from keras.src.backend.openvino.distribution_lib import process_id
2324
from keras.src.backend.openvino.rnn import cudnn_ok
2425
from keras.src.backend.openvino.rnn import gru
2526
from keras.src.backend.openvino.rnn import lstm
2627
from keras.src.backend.openvino.rnn import rnn
27-
from keras.src.backend.openvino.distribution_lib import process_id

keras/src/backend/openvino/distribution_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33

44
def process_id():
55
"""Return the current process ID for the distribution setting."""
6-
return 0
6+
return 0

keras/src/backend/tensorflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
from keras.src.backend.tensorflow.core import shape
2424
from keras.src.backend.tensorflow.core import stop_gradient
2525
from keras.src.backend.tensorflow.core import vectorized_map
26+
from keras.src.backend.tensorflow.distribution_lib import process_id
2627
from keras.src.backend.tensorflow.rnn import cudnn_ok
2728
from keras.src.backend.tensorflow.rnn import gru
2829
from keras.src.backend.tensorflow.rnn import lstm
2930
from keras.src.backend.tensorflow.rnn import rnn
30-
from keras.src.backend.tensorflow.distribution_lib import process_id

keras/src/backend/torch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@
3939
from keras.src.backend.torch.core import stop_gradient
4040
from keras.src.backend.torch.core import to_torch_dtype
4141
from keras.src.backend.torch.core import vectorized_map
42+
from keras.src.backend.torch.distribution_lib import process_id
4243
from keras.src.backend.torch.rnn import cudnn_ok
4344
from keras.src.backend.torch.rnn import gru
4445
from keras.src.backend.torch.rnn import lstm
4546
from keras.src.backend.torch.rnn import rnn
46-
from keras.src.backend.torch.distribution_lib import process_id

keras/src/backend/torch/distribution_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@ def process_id():
1010
return dist.get_rank()
1111
return 0
1212
except (ImportError, AttributeError):
13-
return 0
13+
return 0

keras/src/callbacks/orbax_checkpoint.py

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
def _get_state_tree(model):
4242
"""Get the complete model state as a nested tree structure."""
4343
state_tree = model.get_state_tree(value_format="numpy_array")
44-
44+
4545
# Convert numpy scalar types to Python types for Orbax compatibility
4646
def convert_scalars(obj):
4747
if isinstance(obj, np.ndarray) and obj.ndim == 0:
@@ -54,29 +54,30 @@ def convert_scalars(obj):
5454
return {k: convert_scalars(v) for k, v in obj.items()}
5555
else:
5656
return obj
57-
57+
5858
return convert_scalars(state_tree)
5959

6060

6161
def _flatten_state_tree_values(state_tree):
6262
"""Flatten nested state tree into a list of values in consistent order."""
6363
values = []
64+
6465
def _flatten(obj):
6566
if isinstance(obj, dict):
6667
for key in sorted(obj.keys()): # Sort for consistent ordering
6768
_flatten(obj[key])
6869
else:
6970
# Save any non-dict value (numpy arrays, lists, scalars, etc.)
7071
values.append(obj)
72+
7173
_flatten(state_tree)
7274
return values
7375

7476

7577
def _reconstruct_state_tree_with_values(structure, values):
7678
"""Reconstruct state tree structure with provided values."""
77-
result = {}
7879
value_iter = iter(values)
79-
80+
8081
def _reconstruct(obj):
8182
if isinstance(obj, dict):
8283
new_dict = {}
@@ -109,7 +110,7 @@ def _reconstruct(obj):
109110
return np.array(value, dtype=obj.dtype).reshape(obj.shape)
110111
else:
111112
return value
112-
113+
113114
return _reconstruct(structure)
114115

115116

@@ -128,15 +129,10 @@ def _restore_legacy_format(
128129
target_model.weights[i].assign(weight_tensor)
129130

130131
# Restore optimizer state if available
131-
if (
132-
"optimizer_state" in checkpoint_data
133-
and save_optimizer_state
134-
):
132+
if "optimizer_state" in checkpoint_data and save_optimizer_state:
135133
optimizer_vars_np = checkpoint_data["optimizer_state"]
136134
# Only restore if the variable counts match
137-
if len(optimizer_vars_np) == len(
138-
target_model.optimizer.variables
139-
):
135+
if len(optimizer_vars_np) == len(target_model.optimizer.variables):
140136
# Convert NumPy arrays back to backend tensors and assign to
141137
# optimizer
142138
for i, var_np in enumerate(optimizer_vars_np):
@@ -406,14 +402,14 @@ def _save_checkpoint(self, step, logs=None):
406402
trainable_values = _flatten_state_tree_values(
407403
state_tree["trainable_variables"]
408404
)
409-
405+
410406
# Save optimizer and metrics state if requested
411407
optimizer_values = None
412408
if self.save_optimizer_state and "optimizer_variables" in state_tree:
413409
optimizer_values = _flatten_state_tree_values(
414410
state_tree["optimizer_variables"]
415411
)
416-
412+
417413
metrics_values = None
418414
if self.save_metrics_state and "metrics_variables" in state_tree:
419415
metrics_values = _flatten_state_tree_values(
@@ -423,7 +419,7 @@ def _save_checkpoint(self, step, logs=None):
423419
composite_state = {
424420
"model_weights": trainable_values,
425421
}
426-
422+
427423
if optimizer_values is not None:
428424
composite_state["optimizer_state"] = optimizer_values
429425
if metrics_values is not None:
@@ -611,8 +607,9 @@ def _restore_model_state(self, checkpoint_data, model=None):
611607
target_model = model if model is not None else self.model
612608

613609
# Check if this is the new flattened format
614-
if ("model_weights" in checkpoint_data and
615-
isinstance(checkpoint_data["model_weights"], list)):
610+
if "model_weights" in checkpoint_data and isinstance(
611+
checkpoint_data["model_weights"], list
612+
):
616613
# New format: flattened values
617614
return self._restore_from_flattened_values(
618615
checkpoint_data, target_model
@@ -625,8 +622,10 @@ def _restore_model_state(self, checkpoint_data, model=None):
625622
else:
626623
# Fallback to legacy format
627624
_restore_legacy_format(
628-
checkpoint_data, target_model, self.save_optimizer_state,
629-
self.save_metrics_state
625+
checkpoint_data,
626+
target_model,
627+
self.save_optimizer_state,
628+
self.save_metrics_state,
630629
)
631630
return True
632631

@@ -649,9 +648,9 @@ def _restore_from_flattened_values(self, checkpoint_data, target_model):
649648
# Restore trainable variables
650649
if "model_weights" in checkpoint_data:
651650
saved_trainable_values = checkpoint_data["model_weights"]
652-
target_trainable_structure = (
653-
target_state_tree["trainable_variables"]
654-
)
651+
target_trainable_structure = target_state_tree[
652+
"trainable_variables"
653+
]
655654
reconstructed_state["trainable_variables"] = (
656655
_reconstruct_state_tree_with_values(
657656
target_trainable_structure, saved_trainable_values
@@ -665,9 +664,9 @@ def _restore_from_flattened_values(self, checkpoint_data, target_model):
665664
and "optimizer_variables" in target_state_tree
666665
):
667666
saved_optimizer_values = checkpoint_data["optimizer_state"]
668-
target_optimizer_structure = (
669-
target_state_tree["optimizer_variables"]
670-
)
667+
target_optimizer_structure = target_state_tree[
668+
"optimizer_variables"
669+
]
671670
reconstructed_state["optimizer_variables"] = (
672671
_reconstruct_state_tree_with_values(
673672
target_optimizer_structure, saved_optimizer_values
@@ -702,5 +701,3 @@ def _restore_from_state_tree(self, state_tree, target_model):
702701
if self.verbose > 0:
703702
print_msg("OrbaxCheckpoint: Successfully restored model state")
704703
return True
705-
706-

keras/src/callbacks/orbax_checkpoint_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,8 @@ def test_checkpoint_error_handling(self):
359359
with self.assertRaises(Exception):
360360
callback.load_checkpoint(step=999)
361361

362-
# Test: Try to load latest when no checkpoints exist - should raise FileNotFoundError
362+
# Test: Try to load latest when no checkpoints exist -
363+
# should raise FileNotFoundError
363364
with self.assertRaises(FileNotFoundError):
364365
callback.load_latest()
365366

0 commit comments

Comments (0)