Skip to content

Commit

Permalink
Merge lora gpu support
Browse files Browse the repository at this point in the history
  • Loading branch information
wheresmyhair committed Apr 30, 2024
1 parent 1ae1e1c commit c6bcee3
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
6 changes: 6 additions & 0 deletions examples/merge_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ class MergeLoraArguments:
"help": "output merged full model path"
},
)
local_rank: Optional[int] = field(
default=-1,
metadata={
"help": "local rank for deepspeed",
},
)


def main():
Expand Down
2 changes: 1 addition & 1 deletion scripts/run_merge_lora.sh
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ elif [ ${device} == "gpu" ]; then
--lora_model_path ${lora_model_path} \
--output_model_path ${output_model_path} \
--device ${device} \
--ds_config configs/ds_config_eval.json
--ds_config configs/ds_config_zero3_for_eval.json
else
echo "error: unknown device \"${device}\"" 1>&2
exit 1
Expand Down
9 changes: 8 additions & 1 deletion src/lmflow/models/hf_decoder_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def __init__(
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
logger.debug(f"torch_dtype on init: {torch_dtype}")

config_kwargs = {
"cache_dir": model_args.cache_dir,
Expand Down Expand Up @@ -852,7 +853,13 @@ def save(self, dir, save_full_model=False, *args, **kwargs):
"""
self.get_tokenizer().save_pretrained(dir)
if save_full_model and self.model_args.use_lora:
self.backend_model_full.save_pretrained(dir)
save_dtype = (
torch.float16
if self.model_args.torch_dtype in ["auto", None]
else getattr(torch, self.model_args.torch_dtype)
)
self.backend_model_full.to(dtype=save_dtype).save_pretrained(dir)
logger.warning(f"Save full model with dtype: {save_dtype}")
else:
self.get_backend_model().save_pretrained(dir)

Expand Down

0 comments on commit c6bcee3

Please sign in to comment.