diff --git a/litgpt/chat/base.py b/litgpt/chat/base.py
index 81229ddd2a..7f2afc8f19 100644
--- a/litgpt/chat/base.py
+++ b/litgpt/chat/base.py
@@ -131,15 +131,15 @@ def main(
 
     fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)
 
-    check_valid_checkpoint_dir(checkpoint_dir)
-    config = Config.from_file(checkpoint_dir / "model_config.yaml")
-
     checkpoint_path = checkpoint_dir / "lit_model.pth"
 
     # Merge if this is a raw LoRA checkpoint
-    if (checkpoint_path / "lit_model.pth.lora").is_file() and not checkpoint_path.is_file():
+    if (checkpoint_dir / "lit_model.pth.lora").is_file() and not checkpoint_path.is_file():
         print("Merging LoRA weights with the base model. This won't take long and is a one-time-only thing.")
-        merge_lora(checkpoint_path)
+        merge_lora(checkpoint_dir)
+
+    check_valid_checkpoint_dir(checkpoint_dir)
+    config = Config.from_file(checkpoint_dir / "model_config.yaml")
 
     with fabric.init_module(empty_init=True):
         model = GPT(config)
diff --git a/tests/test_chat.py b/tests/test_chat.py
index 4ac44aff56..3f456e7421 100644
--- a/tests/test_chat.py
+++ b/tests/test_chat.py
@@ -1,4 +1,5 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+import os
 import re
 import subprocess
 import sys
@@ -15,6 +16,8 @@
 
 import litgpt.chat.base as chat
 import litgpt.generate.base as generate
+from litgpt import Config
+from litgpt.utils import save_config
 
 
 @pytest.mark.parametrize(
@@ -129,3 +132,26 @@ def test_cli(mode):
     output = subprocess.check_output(args)
     output = str(output.decode())
     assert "Starts a conversation" in output
+
+
+@patch("litgpt.chat.base.input")
+@patch("litgpt.chat.base.merge_lora")
+def test_merge_lora_if_needed(mocked_merge_lora, mocked_input, fake_checkpoint_dir, monkeypatch, tensor_like):
+    # these values will be iteratively provided for each `input()` call
+    mocked_input.side_effect = [""]
+
+    # pretend there is an unmerged LORA checkpoint
+    os.rename(fake_checkpoint_dir / "lit_model.pth", fake_checkpoint_dir / "lit_model.pth.lora")
+    mocked_merge_lora.side_effect = lambda _: Path(fake_checkpoint_dir / "lit_model.pth").touch()
+
+    config = Config.from_name("pythia-14m")
+    save_config(config, fake_checkpoint_dir)
+    monkeypatch.setattr(chat, "load_checkpoint", Mock())
+    monkeypatch.setattr(chat, "Tokenizer", Mock())
+
+    out, err = StringIO(), StringIO()
+    with redirect_stdout(out), redirect_stderr(err):
+        chat.main(checkpoint_dir=fake_checkpoint_dir)
+
+    assert re.match("Merging LoRA weights with the base model.", out.getvalue(), re.DOTALL)
+    mocked_merge_lora.assert_called_once()
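
Note on the fix above: in the old code the LoRA marker was resolved against `checkpoint_path` (a file), so the probe pointed at `.../lit_model.pth/lit_model.pth.lora`, which can never be an existing file, and the merge was silently skipped; resolving it against `checkpoint_dir` instead, and moving `check_valid_checkpoint_dir` / `Config.from_file` after the merge, means validation sees the freshly written `lit_model.pth`. Below is a minimal standalone sketch of the path behaviour, using a hypothetical checkpoint directory name (not taken from the diff):

from pathlib import Path

checkpoint_dir = Path("checkpoints/EleutherAI/pythia-14m")  # hypothetical location
checkpoint_path = checkpoint_dir / "lit_model.pth"

# Old probe: joined onto the *file* path, so it can never exist on disk.
old_probe = checkpoint_path / "lit_model.pth.lora"  # .../lit_model.pth/lit_model.pth.lora
# New probe: joined onto the directory, where a raw LoRA checkpoint actually lives.
new_probe = checkpoint_dir / "lit_model.pth.lora"   # .../lit_model.pth.lora

print(old_probe, old_probe.is_file())
print(new_probe, new_probe.is_file())

The new `test_merge_lora_if_needed` test exercises this path by renaming `lit_model.pth` to `lit_model.pth.lora`, letting the mocked `merge_lora` recreate `lit_model.pth`, and asserting that the merge message is printed and `merge_lora` is called exactly once.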