Add back meta-device assign=True loading in merge_lora (Lightning-AI#1250)

Co-authored-by: Sebastian Raschka <[email protected]>
carmocca and rasbt authored May 7, 2024
1 parent 90a16e4 commit c0d1dd0
Showing 2 changed files with 14 additions and 3 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -16,3 +16,7 @@ checkpoints
 out
 wandb
 events.out.tfevents*
+
+# test artifacts from tests/test_readme.py
+tests/custom_finetuning_dataset.json
+tests/custom_texts
13 changes: 10 additions & 3 deletions litgpt/scripts/merge_lora.py
@@ -43,16 +43,23 @@ def merge_lora(
     fabric = L.Fabric(devices=1, precision=precision, accelerator="cpu")
     config = Config.from_file(checkpoint_dir / "model_config.yaml", **lora_params)
 
-    with fabric.init_module():
+    with fabric.init_module(), torch.device("meta"):
         model = GPT(config)
+        # we don't care about these to perform merging
+        model.cos = None
+        model.sin = None
 
     lora_path = checkpoint_dir / "lit_model.pth.lora"
     pretrained_checkpoint = torch.load(str(pretrained_checkpoint_dir / "lit_model.pth"), mmap=True)
     lora_checkpoint = torch.load(str(lora_path), mmap=True)
+    lora_checkpoint = lora_checkpoint.get("model", lora_checkpoint)
 
     # Merge LoRA weights into the base model
-    pretrained_checkpoint.update(lora_checkpoint.get("model", lora_checkpoint))
-    model.load_state_dict(pretrained_checkpoint)
+    pretrained_checkpoint.update(lora_checkpoint)
+    model.load_state_dict(pretrained_checkpoint, assign=True)
+    # since LoRA finetuning only saves the LoRA weights, we treat the lora weights dtype as the expected dtype
+    lora_dtype = next(iter(lora_checkpoint.values())).dtype
+    model.to(dtype=lora_dtype, device="cpu")
     merge_lora_weights(model)
 
     # Remove LoRA parameters and the LoRA linear substring
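Editor's note: the change above relies on a standard PyTorch pattern named in the commit title: construct the module on the meta device, which allocates no real parameter storage, then materialize it with load_state_dict(..., assign=True), which swaps the checkpoint tensors into the module rather than copying into the (storage-less) meta parameters. Below is a minimal, self-contained sketch of that pattern; TinyModel is a toy stand-in for litgpt's GPT, and all names in it are illustrative, not litgpt APIs.

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

# Instantiating under the meta device allocates no parameter storage.
with torch.device("meta"):
    model = TinyModel()
assert model.linear.weight.is_meta

# A concrete checkpoint (built here from a real copy of the model).
state_dict = TinyModel().state_dict()

# assign=True replaces the meta parameters with the checkpoint tensors;
# the default copy-in-place behavior would fail on meta tensors.
model.load_state_dict(state_dict, assign=True)
assert not model.linear.weight.is_meta

# As in the diff above, tensors loaded with assign=True keep their
# checkpoint dtypes, so a final .to() call normalizes the model to one
# dtype on CPU.
model.to(dtype=next(iter(state_dict.values())).dtype, device="cpu")

This also explains the model.to(dtype=lora_dtype, device="cpu") line in the diff: since LoRA finetuning saves only the LoRA weights, their dtype is taken as the expected dtype for the merged model.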
