Update QuantizationRecipe to use checkpointer.save_checkpoint #2257
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2257
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit b42649e with merge base e79ab8b.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@felipemello1 There are two caveats that I think are important to highlight
Hey @Ankur-singh, thanks for the PR! I just came back from PTO. I will get to this PR this week.
Just a couple of comments, otherwise looks good
recipes/quantize.py
Outdated
"Setting safe_serialization to False. TorchAO quantization is compatible " | ||
"only with HuggingFace's non-safetensor serialization and deserialization." | ||
) | ||
checkpointer_cfg.safe_serialization = False |
this is only an argument for the HF checkpointer and not the others, so the instantiation would fail for other checkpointer classes. Maybe you could check for the attribute after it's instantiated and then set it to False if it's present?
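A minimal sketch of that suggestion (the exact attribute name on the instantiated checkpointer is an assumption):

```python
from torchtune import config

# Instantiate whichever checkpointer the config specifies, then only touch
# the HF-only flag if this particular class actually exposes it.
checkpointer = config.instantiate(checkpointer_cfg)
if hasattr(checkpointer, "safe_serialization"):
    # TorchAO-quantized weights can't be serialized as safetensors.
    checkpointer.safe_serialization = False
```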
That is a very good point. I will make the required changes.
f"saved to {checkpoint_file}" | ||
) | ||
ckpt_dict = {training.MODEL_KEY: self._model.state_dict()} | ||
self._checkpointer.save_checkpoint(ckpt_dict, epoch=0) |
This does clean up the code quite a bit, but now the files are saved with an awkward name in an `epoch_0` folder. I suppose that is alright for now since this is meant to be an example, and we don't have a way of changing the checkpoint subfolder name. But maybe at some point in the future it would be nice to control.
Totally agreed, the file name is a bit awkward. However, since `epoch` is a required argument, we can't get around it. We should circle back to this sometime in the future.
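For illustration, one hypothetical way to expose that control later (the `dir_name` parameter and helper name are invented for this sketch, not part of the current API):

```python
from pathlib import Path
from typing import Optional

# Hypothetical helper: the default mirrors today's behavior
# (<output_dir>/epoch_<epoch>), but callers could override the subfolder name.
def resolve_ckpt_dir(output_dir: str, epoch: int, dir_name: Optional[str] = None) -> Path:
    subfolder = dir_name if dir_name is not None else f"epoch_{epoch}"
    ckpt_dir = Path(output_dir) / subfolder
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    return ckpt_dir
```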
Hey @Ankur-singh, just before we merge: were you able to run the recipe and load the weights / run inference as part of testing? Maybe just chat with it a little to confirm it is loaded and works?
Hello @felipemello1, I see the model is quantized, but I'm unable to read it with
I will check the torchao documentation and try to make it work. Do you have any suggestions?
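For what it's worth, the usual TorchAO pattern is to quantize a freshly built model first so the parameter layouts match, and only then load the quantized state dict. A sketch (the checkpoint path is a placeholder, and the groupsize is assumed to match the recipe's quantizer config):

```python
import torch
from torchtune.models.qwen2_5 import qwen2_5_0_5b
from torchtune.training.quantization import Int8DynActInt4WeightQuantizer

model = qwen2_5_0_5b()
quantizer = Int8DynActInt4WeightQuantizer(groupsize=256)
model = quantizer.quantize(model)  # swap modules to quantized layouts first

# Placeholder path; point this at whatever file the quantize recipe wrote.
sd = torch.load("/tmp/quantized-model.pt", map_location="cpu", weights_only=True)
model.load_state_dict(sd, assign=True)
```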
Blocking this from being merged by accident until we test the model checkpoint.
@ebsmothers @joecummings what's the best way to test the quantized model for inference? Does it work with our FullModelHFCheckpointer?
@felipemello1 Sorry for the delayed response. Over the weekend, I did some digging and found that quantization is only supported by `FullModelTorchTuneCheckpointer`.
So, I did the following to test the PR:
Here is the code used for step 1:

```python
from torchtune.models.qwen2_5 import qwen2_5_0_5b
from torchtune.training.checkpointing._checkpointer import (
    FullModelTorchTuneCheckpointer,
    FullModelHFCheckpointer,
)

hf_ckptr = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/Qwen2_5-0_5B-Instruct/",
    checkpoint_files=["model.safetensors"],
    model_type="QWEN2",
    output_dir="./",
)
ckpt_dict = hf_ckptr.load_checkpoint()  # in torchtune format

# Don't worry about the arguments much; we will only be using the `save_checkpoint`
# method, which only makes use of `checkpoint_dir` and `output_dir`.
tt_ckptr = FullModelTorchTuneCheckpointer(
    checkpoint_dir="/tmp/Qwen2_5-0_5B-Instruct/",
    checkpoint_files=["model.safetensors"],
    model_type="QWEN2",
    output_dir="./",
)
tt_ckptr.save_checkpoint(state_dict=ckpt_dict, epoch=0)

# ------------------------
# Model checkpoint of size 0.92 GiB saved to epoch_0/ft-model-00001-of-00001.bin
# Saving final epoch checkpoint.
# The full model checkpoint, including all weights and configurations, has been
# saved successfully. You can now use this checkpoint for further training or inference.
```

With the quantized model, I'm able to get a similar score to the unquantized model. Here is the output from both:

Unquantized Model:

(tune) ➜ torchtune git:(refactor/quantization-recipe-save-checkpoint) ✗ tune run eleuther_eval --config cust_eval.yaml
Running EleutherEvalRecipe with resolved config:
batch_size: 8
checkpointer:
_component_: torchtune.training.FullModelTorchTuneCheckpointer
checkpoint_dir: ./epoch_0
checkpoint_files:
- ft-model-00001-of-00001.bin
model_type: QWEN2
output_dir: ./
device: cuda
dtype: bf16
enable_kv_cache: true
limit: null
max_seq_length: 4096
model:
_component_: torchtune.models.qwen2_5.qwen2_5_0_5b
output_dir: ./
quantizer:
_component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
groupsize: 256
seed: 1234
tasks:
- truthfulqa_mc2
tokenizer:
_component_: torchtune.models.qwen2_5.qwen2_5_tokenizer
max_seq_len: null
merges_file: ./epoch_0/merges.txt
path: ./epoch_0/vocab.json
Model is initialized with precision torch.bfloat16.
2025-02-03:16:15:26,274 INFO [huggingface.py:132] Using device 'cuda:0'
2025-02-03:16:15:27,230 INFO [huggingface.py:369] Model parallel was set to False, max memory was not set, and device map was set to {'': 'cuda:0'}
Running evaluation on the following tasks: ['truthfulqa_mc2']
2025-02-03:16:15:36,446 INFO [task.py:415] Building contexts for truthfulqa_mc2 on rank 0...
100%|██████████| 817/817 [00:00<00:00, 1652.22it/s]
2025-02-03:16:15:36,971 INFO [evaluator.py:496] Running loglikelihood requests
Running loglikelihood requests: 100%|██████████| 5882/5882 [05:20<00:00, 18.37it/s]
Eval completed in 322.81 seconds.
Max memory allocated: 8.68 GB
| Tasks |Version|Filter|n-shot|Metric| |Value | |Stderr|
|--------------|------:|------|-----:|------|---|-----:|---|-----:|
|truthfulqa_mc2| 2|none | 0|acc |↑ |0.4178|± |0.0146|
Quantized Model:
You can see a clear difference in eval speed (322.81 s vs. 84.17 s) and max memory allocated (8.68 GB vs. 5.53 GB). Also, here is the output from the
Context
What is the purpose of this PR? Is it to
Please link to any issues this PR addresses: #2229
Changelog
What are the changes made in this PR?
- Updated the `save_checkpoint` method by making use of `checkpointer.save_checkpoint` instead of a custom implementation.
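For context, a minimal sketch of the shape of the change. The "after" lines are the new recipe code from the diff above; the "before" branch is reconstructed from the discussion and is an assumption, not the exact removed code:

```python
# Before (sketch): the recipe assembled its own output path and saved directly.
# checkpoint_file = Path(self._output_dir) / "model-quantized.pt"  # hypothetical name
# torch.save(self._model.state_dict(), checkpoint_file)

# After: delegate to the configured checkpointer, which owns file naming and layout.
ckpt_dict = {training.MODEL_KEY: self._model.state_dict()}
self._checkpointer.save_checkpoint(ckpt_dict, epoch=0)
```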
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- run pre-commit hooks and linters (make sure you've first installed via `pre-commit install`)
- run unit tests via `pytest tests`
- run recipe tests via `pytest tests -m integration_test`
Quantization recipe:
Output:
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.