forked from mlcommons/algorithmic-efficiency
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcheckpoint_utils.py
244 lines (216 loc) · 8.79 KB
/
checkpoint_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
"""Utilities for checkpointing.
Note: Code adapted from
https://github.com/google/init2winit/blob/master/init2winit/checkpoint.py.
"""
import os
from typing import Sequence, Tuple
from absl import logging
from flax import jax_utils
from flax.training import checkpoints as flax_checkpoints
from flax.training.checkpoints import latest_checkpoint
import jax
import numpy as np
from tensorflow.io import gfile # pytype: disable=import-error
import torch
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
# pytorch_setup() yields a 4-tuple; only its third element, the device used
# as `map_location` when loading PyTorch checkpoints, is kept here.
_, _, DEVICE, _ = pytorch_setup()

# Return type of `maybe_restore_checkpoint`, in this exact order.
CheckpointReturn = Tuple[
    spec.OptimizerState,  # optimizer_state
    spec.ParameterContainer,  # model_params
    spec.ModelAuxiliaryState,  # model_state
    dict,  # train_state
    list,  # eval_results
    int,  # global_step
    int,  # preemption_count
]
def maybe_restore_checkpoint(framework: str,
                             optimizer_state: spec.OptimizerState,
                             model_params: spec.ParameterContainer,
                             model_state: spec.ModelAuxiliaryState,
                             train_state: dict,
                             eval_results: list,
                             global_step: int,
                             preemption_count: int,
                             checkpoint_dir: str) -> CheckpointReturn:
  """Optionally restores from a checkpoint.

  The checkpoint logic is as follows: if there is a checkpoint in
  `checkpoint_dir`, restore it. Else, don't restore any checkpoint, and
  just return all passed-in values unchanged.

  For PyTorch, the saved weights and optimizer states are loaded in place
  into the passed-in model and optimizer objects, and those same objects are
  returned. A model wrapped in `torch.nn.DataParallel` or
  `torch.nn.parallel.DistributedDataParallel` is therefore returned still
  wrapped, exactly as in the no-checkpoint case. Optimizer entries for which
  the checkpoint holds no state, or which lack `load_state_dict()`, are left
  untouched and a warning is logged.

  Args:
    framework: Current framework (e.g., `jax` or `pytorch`).
    optimizer_state: Optimizer state. For `jax` this is a
      `(opt_state, opt_update_fn)` pair; for `pytorch` a dict-like mapping
      of names to objects exposing `load_state_dict()`.
    model_params: Model parameters.
    model_state: Model state such as batch statistics when batch
      normalization is used.
    train_state: Training state such as `last_eval_time`.
    eval_results: Previous evaluation results.
    global_step: Global step.
    preemption_count: Number of preemptions.
    checkpoint_dir: The training directory where we will look for a checkpoint.

  Returns:
    A tuple of (optimizer_state, model_params, model_state,
    train_state, eval_results, global_step, preemption_count). When a
    checkpoint is restored, preemption_count is the saved count plus one.
  """
  if framework == 'jax':
    opt_state, opt_update_fn = optimizer_state
  else:
    opt_state, opt_update_fn = optimizer_state, None

  # Sentinel values: if global_step is still -1 after the restore attempt,
  # no checkpoint was found.
  uninitialized_global_step = -1
  uninitialized_preemption_count = -1
  checkpoint_state = {
      'model_params': model_params,
      'optimizer_state': opt_state,
      'model_state': model_state,
      'train_state': train_state,
      'eval_results': None,
      'global_step': uninitialized_global_step,
      'preemption_count': uninitialized_preemption_count,
  }

  if framework == 'jax':
    # restore_checkpoint() returns `target` (checkpoint_state) unchanged when
    # checkpoint_dir does not exist or contains no checkpoints.
    latest_ckpt = flax_checkpoints.restore_checkpoint(
        checkpoint_dir, target=checkpoint_state)
    save_path = os.path.join(checkpoint_dir,
                             'checkpoint_' + str(latest_ckpt['global_step']))
  else:
    latest_ckpt = checkpoint_state
    save_path = latest_checkpoint(checkpoint_dir)
    if save_path is not None:
      latest_ckpt = torch.load(save_path, map_location=DEVICE)

  found_checkpoint = latest_ckpt['global_step'] != uninitialized_global_step
  if not found_checkpoint:
    return (optimizer_state,
            model_params,
            model_state,
            train_state,
            eval_results,
            global_step,
            preemption_count)

  # There is a checkpoint in checkpoint_dir; restore from it.
  if framework == 'jax':
    checkpoint_state = replicate_checkpoint(
        latest_ckpt,
        pytree_keys=[
            'optimizer_state',
            'model_params',
            'model_state',
        ])
    checkpoint_state['optimizer_state'] = (checkpoint_state['optimizer_state'],
                                           opt_update_fn)
    # NOTE(review): eval_results is restored without a target, so it comes
    # back as a mapping (presumably keyed by stringified tuple indices); this
    # builds (value, key) pairs from it. Confirm this matches the element
    # layout callers expect.
    checkpoint_state['eval_results'] = [
        (value, key) for key, value in latest_ckpt['eval_results'].items()
    ]
  else:
    checkpoint_state = latest_ckpt

    # The saved weights are the unwrapped module's state_dict
    # (save_checkpoint strips DataParallel/DDP before calling state_dict()),
    # so load them into the inner module...
    module = model_params
    if isinstance(
        module,
        (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
      module = module.module
    module.load_state_dict(checkpoint_state['model_params'])
    # ...but hand back the caller's object, wrapper included, so that a
    # restored run keeps the same (e.g. DDP) wrapping as a fresh run.
    checkpoint_state['model_params'] = model_params

    # save_checkpoint skips optimizer entries without state_dict(), so some
    # keys may legitimately be missing from the checkpoint.
    saved_optimizer_state = checkpoint_state['optimizer_state']
    restored_optimizer_state = {}
    for key, optimizer in optimizer_state.items():
      if key in saved_optimizer_state and hasattr(optimizer,
                                                  'load_state_dict'):
        optimizer.load_state_dict(saved_optimizer_state[key])
      else:
        logging.warning(
            f'The optimizer state for key {key} is not restored, because it '
            'is missing from the checkpoint or '
            f'{type(optimizer)} has not implemented a load_state_dict() '
            'method.')
      restored_optimizer_state[key] = optimizer
    checkpoint_state['optimizer_state'] = restored_optimizer_state

  logging.info(f'Loaded checkpoint from {save_path}.')
  return (checkpoint_state['optimizer_state'],
          checkpoint_state['model_params'],
          checkpoint_state['model_state'],
          checkpoint_state['train_state'],
          list(checkpoint_state['eval_results']),
          checkpoint_state['global_step'],
          checkpoint_state['preemption_count'] + 1)
def replicate_checkpoint(latest: dict,
                         pytree_keys: Sequence[str],
                         replicate: bool = True) -> dict:
  """Builds a new checkpoint dict whose selected pytree entries are replicated.

  Args:
    latest: A dict representing the state of the checkpoint we want to
      restore. It is not modified.
    pytree_keys: Keys of `latest` whose values are pytrees; these are passed
      through `jax_utils.replicate` when `replicate` is True.
    replicate: If set, replicate the selected pytrees across devices.

  Returns:
    A new dict holding the `pytree_keys` entries first (replicated if
    requested), followed by every other entry of `latest` unchanged.
  """
  selected = {}
  for name in pytree_keys:
    selected[name] = latest[name]

  if replicate:
    selected = jax_utils.replicate(selected)

  # Carry over everything that was not replicated, preserving its order.
  passthrough = {
      name: value for name, value in latest.items() if name not in pytree_keys
  }
  selected.update(passthrough)
  return selected
def save_checkpoint(framework: str,
                    optimizer_state: spec.OptimizerState,
                    model_params: spec.ParameterContainer,
                    model_state: spec.ModelAuxiliaryState,
                    train_state: dict,
                    eval_results: list,
                    global_step: int,
                    preemption_count: int,
                    checkpoint_dir: str,
                    save_intermediate_checkpoints: bool) -> None:
  """Saves a checkpoint named `checkpoint_<global_step>` in `checkpoint_dir`.

  For `jax`, replicated state is unreplicated and fetched to host before being
  written with flax. For `pytorch`, the (unwrapped) model's `state_dict()` and
  the `state_dict()` of every optimizer entry that implements one are written
  with `torch.save`. Optimizer entries without `state_dict()` are skipped with
  a warning.

  Args:
    framework: Current framework (e.g., `jax` or `pytorch`).
    optimizer_state: Optimizer state. For `jax` this is a
      `(opt_state, opt_update_fn)` pair; for `pytorch` a dict-like mapping of
      names to optimizer-like objects.
    model_params: Model parameters.
    model_state: Model state such as batch statistics when batch
      normalization is used.
    train_state: Training state such as `last_eval_time`.
    eval_results: Previous evaluation results.
    global_step: Global step.
    preemption_count: Number of preemptions.
    checkpoint_dir: The training directory the checkpoint is written to.
    save_intermediate_checkpoints: Whether to keep earlier checkpoints. If
      False, only the checkpoint written by this call is kept.
  """
  if framework == 'jax':
    model_params = jax.device_get(jax_utils.unreplicate(model_params))
    opt_state, _ = optimizer_state
    opt_state = jax.device_get(jax_utils.unreplicate(opt_state))
    model_state = jax.device_get(jax_utils.unreplicate(model_state))
  else:
    # Save the inner module's weights so the checkpoint does not depend on
    # whether the model was wrapped in DataParallel/DDP.
    if isinstance(
        model_params,
        (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
      model_params = model_params.module
    model_params = model_params.state_dict()
    optimizer_state_dict = {}
    for key in optimizer_state.keys():
      if hasattr(optimizer_state[key], 'state_dict'):
        optimizer_state_dict[key] = optimizer_state[key].state_dict()
      else:
        logging.warning(
            f'The optimizer state for key {key} is not saved, because '
            f'{type(optimizer_state[key])} has not implemented a state_dict() '
            'method.')
    opt_state = optimizer_state_dict

  checkpoint_state = {
      'model_params': model_params,
      'optimizer_state': opt_state,
      'model_state': model_state,
      'train_state': train_state,
      'eval_results': tuple(eval_results),
      'global_step': global_step,
      'preemption_count': preemption_count,
  }

  save_path = os.path.join(checkpoint_dir, f'checkpoint_{global_step}')
  if framework == 'jax':
    flax_checkpoints.save_checkpoint(
        checkpoint_dir,
        target=checkpoint_state,
        step=global_step,
        overwrite=True,
        # `np.Inf` was removed in NumPy 2.0; `np.inf` works on all versions.
        keep=np.inf if save_intermediate_checkpoints else 1)
  else:
    # Write the new checkpoint before removing old ones, so an interruption
    # (e.g. preemption) between the two steps never leaves checkpoint_dir
    # without any checkpoint to resume from.
    torch.save(checkpoint_state, save_path)
    if not save_intermediate_checkpoints:
      checkpoint_files = gfile.glob(
          os.path.join(checkpoint_dir, 'checkpoint_*'))
      for path in checkpoint_files:
        # The glob also matches the checkpoint just written; keep it.
        if os.path.basename(path) == os.path.basename(save_path):
          continue
        logging.info('Removing checkpoint at %s', path)
        gfile.rmtree(path)
  logging.info(f'Saved checkpoint to {save_path}.')