fix checkpoint load issue (NVIDIA#11859)
* fix checkpoint load issue

Signed-off-by: Dmytro Pykhtar <[email protected]>

* Apply isort and black reformatting

Signed-off-by: dimapihtar <[email protected]>

* set weights_only to False

Signed-off-by: dimapihtar <[email protected]>

---------

Signed-off-by: Dmytro Pykhtar <[email protected]>
Signed-off-by: dimapihtar <[email protected]>
Signed-off-by: dimapihtar <[email protected]>
Co-authored-by: dimapihtar <[email protected]>
dimapihtar and dimapihtar authored Jan 17, 2025
1 parent 8786345 commit 0cd990d
Showing 11 changed files with 27 additions and 18 deletions.
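
Background for the fix: starting with PyTorch 2.6, torch.load defaults to weights_only=True, which uses a restricted unpickler and rejects checkpoints that contain arbitrary Python objects. NeMo .ckpt files typically bundle configs and other metadata alongside the tensors, so loading them began to fail. The hunks below add weights_only=False to each affected torch.load call (a few hunks are only isort/black reformatting). A minimal sketch of the pattern, with an illustrative checkpoint path that is not from this repository:

    import torch

    ckpt_path = "megatron_model.ckpt"  # hypothetical path; NeMo .ckpt files are full pickles, not bare tensors

    # With PyTorch >= 2.6 the default weights_only=True can raise an UnpicklingError here,
    # so the restricted loader is disabled explicitly, matching this commit.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    # Lightning-style checkpoints usually nest the weights under "state_dict".
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint

weights_only=False runs the full pickle machinery, so it should only be used on checkpoints from trusted sources.
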
2 changes: 1 addition & 1 deletion nemo/collections/llm/gpt/model/ssm.py
@@ -120,7 +120,7 @@ def init(self) -> GPTModel:

def apply(self, output_path: Path) -> Path:

- source = torch.load(str(self), map_location='cpu')
+ source = torch.load(str(self), map_location='cpu', weights_only=False)
if 'model' in source:
source = source['model']

[next changed file]
@@ -547,7 +547,7 @@ def load_from_unet(self, from_pretrained_unet, from_NeMo=True):
else:
print("Loading unet blocks from sd")

- state_dict = torch.load(from_pretrained_unet, map_location='cpu')
+ state_dict = torch.load(from_pretrained_unet, map_location='cpu', weights_only=False)
if 'state_dict' in state_dict.keys():
state_dict = state_dict['state_dict']
model_state_dict = self.state_dict()
[next changed file]
@@ -83,7 +83,7 @@ def _load_model(model_ckpt: str, model_cfg: str, eval_mode: bool = True, trainer
model_cfg.model.micro_batch_size = 1
model_cfg.model.global_batch_size = 1
model = MegatronImagen(cfg=model_cfg.model, trainer=trainer)
- checkpoint = torch.load(model_ckpt, map_location=lambda storage, loc: storage)
+ checkpoint = torch.load(model_ckpt, map_location=lambda storage, loc: storage, weights_only=False)

# Change weight keys if training using TorchInductor
state_dict = checkpoint['state_dict']
[next changed file]
@@ -41,9 +41,15 @@

class LatentDiffusionEdit(LatentDiffusion):
def init_from_ckpt(
- self, path, ignore_keys=list(), only_model=False, load_vae=True, load_unet=True, load_encoder=True,
+ self,
+ path,
+ ignore_keys=list(),
+ only_model=False,
+ load_vae=True,
+ load_unet=True,
+ load_encoder=True,
):
- pl_sd = torch.load(path, map_location="cpu")
+ pl_sd = torch.load(path, map_location="cpu", weights_only=False)
if "state_dict" in list(pl_sd.keys()):
pl_sd = pl_sd["state_dict"]
sd = {}
@@ -144,7 +150,7 @@ def model_provider_func(self, pre_process=True, post_process=True):
return model

def setup(self, stage=None):
""" PTL hook that is executed after DDP spawns.
"""PTL hook that is executed after DDP spawns.
We setup datasets here as megatron datasets require DDP to instantiate.
See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information.
Args:
@@ -260,5 +266,8 @@ def build_pretraining_data_loader(self, dataset, consumed_samples, drop_last=Tru

# Torch dataloader.
return torch.utils.data.DataLoader(
- dataset, batch_sampler=batch_sampler, num_workers=self._cfg.data.num_workers, pin_memory=True,
+ dataset,
+ batch_sampler=batch_sampler,
+ num_workers=self._cfg.data.num_workers,
+ pin_memory=True,
)
[next changed file]
@@ -88,7 +88,7 @@ def ema_scope(self, context=None):
print(f"{context}: Restored training weights")

def init_from_ckpt(self, path, ignore_keys=list()):
- sd = torch.load(path, map_location="cpu")["state_dict"]
+ sd = torch.load(path, map_location="cpu", weights_only=False)["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
@@ -345,7 +345,7 @@ def __init__(

state_dict = load_safetensors(from_pretrained)
else:
- state_dict = torch.load(from_pretrained)
+ state_dict = torch.load(from_pretrained, weights_only=False)
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
missing_key, unexpected_key, _, _ = self._load_pretrained_model(state_dict, from_NeMo=from_NeMo)
@@ -476,7 +476,7 @@ def load(module: torch.nn.Module, prefix=""):
return error_msgs

def init_from_ckpt(self, path, ignore_keys=list()):
- sd = torch.load(path, map_location="cpu")["state_dict"]
+ sd = torch.load(path, map_location="cpu", weights_only=False)["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
[next changed file]
@@ -246,7 +246,7 @@ def init_from_ckpt(
load_unet=True,
load_encoder=True,
):
- pl_sd = torch.load(path, map_location="cpu")
+ pl_sd = torch.load(path, map_location="cpu", weights_only=False)
if "state_dict" in list(pl_sd.keys()):
pl_sd = pl_sd["state_dict"]

@@ -2340,7 +2340,7 @@ def _modify_state_dict(state_dict):
if filepath.endswith('.nemo'):
conf, state_dict = self._get_config_and_state_dict_from_nemo(filepath, map_location)
elif filepath.endswith('.ckpt'):
- state_dict = torch.load(filepath, map_location)['state_dict']
+ state_dict = torch.load(filepath, map_location, weights_only=False)['state_dict']
else:
raise RuntimeError(f"{filepath} is not nemo file or ckpt file")
if not peft_cfgs:
[next changed file]
@@ -959,7 +959,7 @@ def __init__(

state_dict = load_safetensors(from_pretrained)
else:
- state_dict = torch.load(from_pretrained, map_location='cpu')
+ state_dict = torch.load(from_pretrained, map_location='cpu', weights_only=False)
if 'state_dict' in state_dict.keys():
state_dict = state_dict['state_dict']
missing_key, unexpected_keys, _, _ = self._load_pretrained_model(state_dict, from_NeMo=from_NeMo)
[next changed file]
@@ -1077,7 +1077,7 @@ def load_adapters_for_inference(cls, cfg: DictConfig, model_cfg: DictConfig, mod
peft_cfg_cls = PEFT_CONFIG_MAP[model_cfg.peft.peft_scheme]
model.load_adapters(cfg.model.peft.restore_from_path, peft_cfg_cls(model_cfg), map_location="cpu")
else:
- torch_state_dict = torch.load(cfg.model.peft.restore_from_path)['state_dict']
+ torch_state_dict = torch.load(cfg.model.peft.restore_from_path, weights_only=False)['state_dict']
model.load_state_dict(torch_state_dict, strict=False)
elif cfg.model.peft.restore_from_ckpt.checkpoint_dir and cfg.model.peft.restore_from_ckpt.checkpoint_name:
checkpoint_path = os.path.join(
@@ -1096,7 +1096,7 @@ def load_adapters_for_inference(cls, cfg: DictConfig, model_cfg: DictConfig, mod
peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme]
model.load_adapters(checkpoint_path, peft_cfgs=peft_cfg_cls(model_cfg), map_location="cpu")
else:
- model.load_state_dict(torch.load(checkpoint_path), strict=False)
+ model.load_state_dict(torch.load(checkpoint_path, weights_only=False), strict=False)
else:
raise NotImplementedError("distributed checkpointing of PEFT weights is not supported")
elif model_cfg.peft.get("peft_scheme", None):
[next changed file]
@@ -55,7 +55,7 @@ def load_adapters(
if filepath.endswith('.nemo'):
conf, state_dict = self._get_config_and_state_dict_from_nemo(filepath, map_location)
elif filepath.endswith('.ckpt'):
- state_dict = torch.load(filepath, map_location)['state_dict']
+ state_dict = torch.load(filepath, map_location, weights_only=False)['state_dict']
else:
raise RuntimeError(f"{filepath} is not nemo file or ckpt file")
if not peft_cfgs:
[next changed file]
@@ -130,7 +130,7 @@ def load_adapters(
sharded_state_dict = self.sharded_state_dict(prefix="model.")
conf, state_dict = self._get_config_and_state_dict_from_nemo(filepath, map_location, sharded_state_dict)
elif filepath.endswith('.ckpt'):
- state_dict = torch.load(filepath, map_location)['state_dict']
+ state_dict = torch.load(filepath, map_location, weights_only=False)['state_dict']
else:
raise RuntimeError(f"{filepath} is not nemo file or ckpt file")
if not self.ptuning_only_and_non_first_stage:
2 changes: 1 addition & 1 deletion nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py
@@ -368,7 +368,7 @@ def load_adapters(
if filepath.endswith('.nemo'):
conf, state_dict = self._get_config_and_state_dict_from_nemo(filepath, map_location)
elif filepath.endswith('.ckpt'):
- state_dict = torch.load(filepath, map_location)['state_dict']
+ state_dict = torch.load(filepath, map_location, weights_only=False)['state_dict']
else:
raise RuntimeError(f"{filepath} is not nemo file or ckpt file")
if not peft_cfgs:
