Commit 0f28c81

feat: add Checkpointer class and usage
Signed-off-by: Charlie Doern <[email protected]>
1 parent e94f8ab commit 0f28c81

3 files changed (+407 / -263 lines)
Lines changed: 385 additions & 0 deletions
@@ -0,0 +1,385 @@
# Standard
from copy import deepcopy
from pathlib import Path
import shutil
import time
import warnings

# Third Party
from instructlab.dolomite.hf_models import export_to_huggingface
from torch import distributed as dist
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
import torch

# First Party
from instructlab.training.accelerator import Accelerator
from instructlab.training.config import DistributedBackend
from instructlab.training.model import Model

# Local
from .utils import log_rank_0, wraps


class Checkpointer:
    def __init__(
        self,
        model: Model,
        optimizer: torch.optim.Optimizer,
        accelerator: Accelerator,
        strategy="all",
    ):
        self.strategy = strategy.lower()
        self.model = model
        self.optimizer = optimizer
        self.accelerator = accelerator

        # Map strategies to internal methods
        self._checkpoint_fn = {
            "full_state": self.save_full_state,
            "hf_format": self.save_hf_format_accelerate,
            "all": self.save_all_checkpoints,
        }.get(self.strategy, self._no_checkpoint)

    def checkpoint(self, *args, **kwargs):
        # Calls the method chosen at init
        return self._checkpoint_fn(*args, **kwargs)

    # pylint: disable=unused-argument
    def _no_checkpoint(self, *args, **kwargs):
        print("[None] Skipping checkpointing.")

    # pylint: disable=unused-argument
    def save_fsdp_lora_model(
        self,
        output_dir: Path,
        **kwargs,
    ):
        """Given a LoRA model wrapped by FSDP and Accelerate, save a full copy of the original
        model with the trained LoRA adapters merged into the copy.

        This function creates a full copy of the model being trained and stores it in CPU memory.
        If encountering OOM errors on CPU, this is likely a culprit.

        Args:
            args (Namespace): Args received by the ArgumentParser.
            model (FSDP): FSDP model as prepared by `accelerate.Accelerator`
            accelerator (Accelerator): The given accelerator object.
        """
        # Third Party
        from peft import LoraModel

        if self.accelerator.distributed_type != DistributedBackend.FSDP:
            raise RuntimeError(
                "`save_fsdp_lora_model` was called when FSDP was not being used."
            )
        if not wraps(self.model, FSDP):
            raise RuntimeError(
                "`save_fsdp_lora_model` was called but provided model is not an FSDP model."
            )
        if not wraps(self.model, LoraModel):
            raise RuntimeError(
                "`save_fsdp_lora_model` was called but provided model is not a LoRA model."
            )

        # okay now that validation is out of the way, we are free to implement saving
        sd_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, sd_config):
            state = self.model.state_dict()

        # When training a LoRA with FSDP and Accelerate, you cannot directly merge the adapters into
        # the model wrapped by FSDP. To get around this limitation, we get a copy of the state dict
        # create an identical model on CPU, load the state dict into the CPU model, merge the adapters
        # and save the model to disk.
        if self.accelerator.is_main_process:
            # Third Party
            from transformers import AutoModelForCausalLM

            # remove device_map from args list so we can load the model on CPU
            old_device_map = self.model.base_model_args.pop("device_map", None)
            model_copy = AutoModelForCausalLM.from_pretrained(
                **self.model.base_model_args, device_map="cpu"
            )
            model_copy = LoraModel(model_copy, self.model.lora_config, "default")
            model_copy.load_state_dict(state)
            model_copy.merge_and_unload(progressbar=True)
            model_copy.save_pretrained(output_dir, safe_serialization=True)
            self.model.config.to_json_file(f"{output_dir}/config.json")
            self.model.tokenizer.save_pretrained(output_dir)
            del model_copy
            if old_device_map:
                # return the previous device_map so it can be used later on if needed
                self.model.base_model_args["device_map"] = old_device_map

        dist.barrier()

    # pylint: disable=unused-argument
    def save_full_state(
        self,
        output_dir,
        epoch: int,
        samples_seen: int,
        **kwargs,
    ):
        """
        Saves model, optimizer, and lr_scheduler state.
        TODO: save model config - decided not to do this.
        TODO: save tokenizer - decided not to do this.
        TODO: handle LoRA
        TODO: handle granite
        """
        if self.model.lora_config is not None:
            raise NotImplementedError("Can't save full state for LoRA at the moment.")

        # if args.is_granite:
        #     raise NotImplementedError("Can't save full state for Granite models yet.")

        output_dir = Path(output_dir) / "full_state" / f"epoch_{epoch}"
        log_rank_0(
            f"\033[93mSaving full model state in {output_dir}\033[0m", to_print=True
        )

        # patch FSDP state dict method so it works correctly.
        def _get_state_dict_patched(model, unwrap=False):
            return get_state_dict_unpatched(model, unwrap=unwrap)

        if self.accelerator.distributed_framework == "fsdp":
            get_state_dict_unpatched = self.accelerator.get_state_dict
            self.accelerator.get_state_dict = _get_state_dict_patched

        self.accelerator.save_state(
            output_dir=output_dir,
            # max_shard_size="5GB",
            # safe_serialization=True,
        )

        # save metadata file for current training status
        if self.accelerator.is_main_process:
            # TODO: should we set the global_step here rather than calculating global_step
            # based on samples_seen?
            metadata = {"current_epoch": epoch, "samples_seen": samples_seen}
            torch.save(metadata, output_dir / "training_metadata.json")
            log_rank_0(
                f"\033[93mSaving training state: {metadata}\033[0m", to_print=True
            )

        log_rank_0(f"\033[93mModel state saved in: {output_dir}\033[0m", to_print=True)

        # cleanup
        if self.accelerator.distributed_framework == "fsdp":
            self.accelerator.get_state_dict = get_state_dict_unpatched

    # pylint: disable=unused-argument
    def save_hf_format_accelerate(
        self,
        output_dir,
        epoch: int,
        samples_seen: int,
        last_epoch: bool = False,
        **kwargs,
    ):
        # Standard
        from tempfile import TemporaryDirectory

        # Build the subdirectory name
        subdir = "last_epoch" if last_epoch else f"samples_{samples_seen}"

        log_rank_0(
            f"\033[93mSaving model in huggingface format at: {subdir}\033[0m",
            to_print=True,
        )
        start = time.time()

        if self.model.model_type in ("gpt_megatron", "gpt_dolomite"):
            convert_dolomite = False
        else:
            convert_dolomite = True

        # Build the final output directory path
        final_output_dir = Path(output_dir) / "hf_format" / subdir

        if self.model.model_type == "dolomite" and convert_dolomite:
            tmpdir = TemporaryDirectory("w")  # pylint: disable=consider-using-with
            output_dir = Path(tmpdir.name)
        else:
            output_dir = final_output_dir

        CONFIG_NAME = "config.json"
        output_config_file = output_dir / CONFIG_NAME

        # XXX(osilkin): LoRA + FSDP requires a different saving path than the others
        # so we set this variable and use it to avoid those paths further down.
        is_fsdp_lora = (
            self.model.lora_config is not None
            and self.accelerator.distributed_type == DistributedBackend.FSDP
        )
        if is_fsdp_lora:
            self.save_fsdp_lora_model(
                model=self.model,
                accelerator=self.accelerator,
                output_dir=output_dir,
            )

        get_state_dict_unpatched = self.accelerator.get_state_dict

        def _get_state_dict_patched(model, unwrap=False):
            return get_state_dict_unpatched(model, unwrap=unwrap)

        self.accelerator.get_state_dict = _get_state_dict_patched

        if not is_fsdp_lora and self.accelerator.is_main_process:
            if self.model.lora_config is not None:
                self.model.module.merge_adapter()
                model_state = self.model.module.state_dict()

            output_dir.mkdir(parents=True, exist_ok=True)
            if not self.model.module.config.architectures and convert_dolomite:
                arch_added = False
                if self.model.model_type == "llama":
                    self.model.module.config.architectures = ["LlamaForCausalLM"]
                    arch_added = True
                elif self.model.model_type == "granite":
                    self.model.module.config.architectures = ["GraniteForCausalLM"]
                    arch_added = True
                if arch_added:
                    warnings.warn(
                        f"Adding architectures to ckpt: {self.model.module.config.architectures}",
                    )
                else:
                    warnings.warn(
                        f"Converting from dolomite, but no architecture field added to config.json",
                    )
            self.model.module.config.to_json_file(output_config_file)
            self.model.tokenizer.save_pretrained(output_dir)

            if self.model.lora_config is not None:
                self.save_dict_accelerate(
                    self.accelerator,
                    model_state,
                    save_directory=output_dir,
                    max_shard_size="5GB",
                    safe_serialization=True,
                )
                self.model.module.unmerge_adapter()

        if self.model.lora_config is None:
            self.accelerator.save_model(
                self.model,
                save_directory=output_dir,
                max_shard_size="5GB",
                safe_serialization=True,
            )

        if (
            self.model.model_type == "dolomite"
            and convert_dolomite
            and self.accelerator.is_main_process
        ):
            # export doesn't like the directory to exist
            if final_output_dir.exists():
                shutil.rmtree(final_output_dir)
            export_to_huggingface(
                pretrained_model_name_or_path=tmpdir.name,
                save_path=final_output_dir,
                model_type=self.model.model_type,
            )
            tmpdir.cleanup()

        log_rank_0(f"\033[93mModel saved in {final_output_dir}\033[0m", to_print=True)
        log_rank_0(f"saving took {time.time() - start} seconds")
        dist.barrier()

        self.accelerator.get_state_dict = get_state_dict_unpatched

    def save_dict_accelerate(
        self,
        accelerator: Accelerator,
        state_to_save,
        save_directory,
        max_shard_size="5GB",
        safe_serialization=True,
    ):
        old_get_state = accelerator.get_state_dict
        accelerator.get_state_dict = self._copy_no_lora_dict

        def skip_precheck_loops():
            return []

        # The save model does a loop over modules and params in order to determine how to get state dict.
        # Since we already have the state dict directly, we want to bypass those checks.
        state_to_save.modules = skip_precheck_loops
        state_to_save.parameters = skip_precheck_loops

        accelerator.save_model(
            state_to_save,
            save_directory=save_directory,
            max_shard_size=max_shard_size,
            safe_serialization=safe_serialization,
        )

        accelerator.get_state_dict = old_get_state

    def _copy_no_lora_dict(self, state_dict):
        # Standard
        from collections import OrderedDict

        cleaned_state_dict = OrderedDict()
        for param_tensor in state_dict:
            if not "lora" in param_tensor:
                cleaned_state_dict[
                    param_tensor.replace(".base_layer", "").replace(
                        "base_model.model.", ""
                    )
                ] = deepcopy(state_dict[param_tensor]).cpu()
        return cleaned_state_dict

    def load_latest_full_state(self, output_dir: Path) -> None:
        """Loads accelerator state from most recently saved checkpoint
        in `output_dir/full_state`.

        Args:
            output_dir: Base output directory containing the full_state subdirectory
        """
        full_state_dir = output_dir / "full_state"

        if not full_state_dir.is_dir():
            return

        # picks checkpoint with the largest number of samples by splitting the "samples_NNNN" string on _
        # and comparing the number at the end of the string
        checkpoint_list = sorted(
            list(full_state_dir.iterdir()),
            reverse=True,
            key=lambda x: int(str(x).rsplit("_", maxsplit=1)[-1]),
        )

        if len(checkpoint_list) == 0:
            log_rank_0(
                f"\033[93mNo checkpoints to load from: {full_state_dir}\033[0m",
                to_print=True,
            )
            return

        latest_checkpoint = checkpoint_list[0]
        log_rank_0(
            f"\033[93mLoading checkpoint from: {latest_checkpoint}\033[0m",
            to_print=True,
        )
        self.accelerator.load_state(latest_checkpoint)

    def save_all_checkpoints(
        self,
        output_dir,
        epoch: int,
        samples_seen: int,
        last_epoch: bool = False,
    ):
        self.save_hf_format_accelerate(
            output_dir=output_dir,
            epoch=epoch,
            samples_seen=samples_seen,
            last_epoch=last_epoch,
        )
        self.save_full_state(
            output_dir=output_dir, epoch=epoch, samples_seen=samples_seen
        )
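
For context, here is a minimal sketch of how the new class might be wired into a training loop. The wrapper function, its parameters, and the import path are illustrative assumptions rather than part of this commit; only the `Checkpointer` constructor, `checkpoint()`, and `load_latest_full_state()` calls come from the code above.

# Hypothetical usage sketch -- not from this commit.
# Assumes `model`, `optimizer`, and `accelerator` were already built with the
# instructlab.training helpers, and that the class lives in a `checkpointer` module.
from pathlib import Path

from instructlab.training.checkpointer import Checkpointer  # assumed import path


def train(model, optimizer, accelerator, output_dir: str, num_epochs: int):
    checkpointer = Checkpointer(
        model=model,
        optimizer=optimizer,
        accelerator=accelerator,
        strategy="all",  # "full_state", "hf_format", or "all"; anything else skips checkpointing
    )

    # Resume from the newest checkpoint under <output_dir>/full_state, if one exists.
    checkpointer.load_latest_full_state(Path(output_dir))

    samples_seen = 0
    for epoch in range(num_epochs):
        # ... run one epoch of training, updating samples_seen ...
        checkpointer.checkpoint(
            output_dir=output_dir,
            epoch=epoch,
            samples_seen=samples_seen,
            last_epoch=(epoch == num_epochs - 1),
        )

With `strategy="all"`, per the path construction in the class, each call would write a Hugging Face format copy under `output_dir/hf_format/samples_<N>` (or `last_epoch`) and a resumable full state under `output_dir/full_state/epoch_<N>`.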
