Update QuantizationRecipe to use checkpointer.save_checkpoint #2257

Open
wants to merge 2 commits into main

Conversation

Ankur-singh
Contributor

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (refactor)

Please link to any issues this PR addresses: #2229

Changelog

What are the changes made in this PR?

  • Simplified the save_checkpoint method by using checkpointer.save_checkpoint instead of a custom implementation (see the sketch below).
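
A minimal sketch of the simplified method (mirroring the diff reviewed later in this conversation; the rest of the recipe class is unchanged):

def save_checkpoint(self) -> None:
    # Collect the (quantized) model weights under the standard model key ...
    ckpt_dict = {training.MODEL_KEY: self._model.state_dict()}
    # ... and delegate writing to the configured checkpointer. epoch=0 is a
    # placeholder: the quantization recipe has no training epochs, but
    # save_checkpoint requires the argument (see the caveats below).
    self._checkpointer.save_checkpoint(ckpt_dict, epoch=0)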

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

Quantization recipe:

# Config for QuantizationRecipe in quantize.py
#
# To launch, run the following command from root torchtune directory:
#    tune download Qwen/Qwen2.5-0.5B-Instruct --output-dir /tmp/Qwen2_5-0_5B-Instruct
#    tune run quantize --config quantization

output_dir: ./quantized # /tmp may be deleted by your system. Change it to your preference.

#
# Model arguments
model:
  _component_: torchtune.models.qwen2_5.qwen2_5_0_5b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2_5-0_5B-Instruct
  checkpoint_files: [model.safetensors]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: QWEN2

device: cuda
dtype: bf16
seed: 1234

quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
  groupsize: 256

Output:

(tune) ➜  torchtune git:(refactor/quantization-recipe-save-checkpoint) ✗ tune run quantize --config ./custom_quant.yaml
Running QuantizationRecipe with resolved config:

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2_5-0_5B-Instruct
  checkpoint_files:
  - model.safetensors
  model_type: QWEN2
  output_dir: ./quantized
  recipe_checkpoint: null
device: cuda
dtype: bf16
model:
  _component_: torchtune.models.qwen2_5.qwen2_5_0_5b
output_dir: ./quantized
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
  groupsize: 256
seed: 1234

Setting manual seed to local seed 1234. Local seed is seed + rank = 1234 + 0
Model is initialized with precision torch.bfloat16.
Time for quantization: 0.08 sec
Memory used: 1.45 GB
Model checkpoint of size 0.82 GiB saved to quantized/epoch_0/ft-model-00001-of-00001.bin
Saving final epoch checkpoint.
The full model checkpoint, including all weights and configurations, has been saved successfully. You can now use this checkpoint for further training or inference.

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Jan 13, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2257

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b42649e with merge base e79ab8b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label Jan 13, 2025
@Ankur-singh
Contributor Author

@felipemello1 There are two caveats that I think are important to highlight

  • epoch is set to 0 as it is a required parameter. Is there a better, more elegant way?

  • safe_serialization is disabled because

    torchao quantization is implemented with tensor subclasses; it only works with HuggingFace non-safetensor serialization and deserialization

    (source)
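
For reference, a minimal, hypothetical illustration of this caveat (not code from this PR; the torchao import path and quantizer API are assumptions):

import torch
from torchao.quantization import Int8DynActInt4WeightQuantizer

# Quantization rewraps the linear weights as torchao tensor subclasses.
model = torch.nn.Sequential(torch.nn.Linear(256, 256))
model = Int8DynActInt4WeightQuantizer(groupsize=256).quantize(model)

# Tensor-subclass weights round-trip through torch.save / torch.load pickling ...
torch.save(model.state_dict(), "/tmp/quantized.bin")

# ... but safetensors only serializes plain tensors, so this is expected
# to raise on the subclass weights.
from safetensors.torch import save_file
save_file(model.state_dict(), "/tmp/quantized.safetensors")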

@felipemello1
Contributor

hey @Ankur-singh, thanks for the PR! I just came back from PTO. I will get to this PR this week.

@felipemello1 felipemello1 self-requested a review January 13, 2025 19:37
@RdoubleA RdoubleA mentioned this pull request Jan 21, 2025

@RdoubleA RdoubleA left a comment


Just a couple of comments, otherwise looks good

"Setting safe_serialization to False. TorchAO quantization is compatible "
"only with HuggingFace's non-safetensor serialization and deserialization."
)
checkpointer_cfg.safe_serialization = False
Contributor

this is only an argument for the HF checkpointer and not the others, so instantiation would fail for other checkpointer classes. Maybe you could check for the attribute after it's instantiated and then set it to False if it's present?
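
A hypothetical sketch of that suggestion (the attribute name `_safe_serialization` is an assumption; check the actual field on FullModelHFCheckpointer, and `logger` stands in for the recipe's logger):

# Instantiate first, then probe for the flag: only the HF checkpointer
# accepts safe_serialization, so injecting it into every checkpointer
# config would break the other checkpointer classes.
checkpointer = config.instantiate(cfg.checkpointer)
if hasattr(checkpointer, "_safe_serialization"):
    logger.info(
        "Setting safe_serialization to False. TorchAO quantization is compatible "
        "only with HuggingFace's non-safetensor serialization and deserialization."
    )
    checkpointer._safe_serialization = False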

Contributor Author

That is a very good point. I will make the required changes.

f"saved to {checkpoint_file}"
)
ckpt_dict = {training.MODEL_KEY: self._model.state_dict()}
self._checkpointer.save_checkpoint(ckpt_dict, epoch=0)
Contributor

This does clean up the code quite a bit, but now the files are saved with an awkward name in an epoch_0 folder. I suppose that is alright for now since this is meant to be an example, and we don't have a way of changing the checkpoint subfolder name. But maybe at some point in the future it would be nice to control that.

Contributor Author

Totally agreed, the file name is a bit awkward. However, since epoch is a required argument, we can't get around it. We should circle back to this sometime in the future.

@felipemello1
Contributor

felipemello1 commented Jan 27, 2025

hey @Ankur-singh, just before we merge: were you able to run the recipe and load the weights / run inference as part of testing? Maybe just chat with it a little to confirm it is loaded and works?

@Ankur-singh
Contributor Author

Hello @felipemello1, I see the model is quantized, but I'm unable to load it with FullModelHFCheckpointer. I'm getting the following error:

While copying the parameter named "layers.23.mlp.w2.weight", whose dimensions in the model are torch.Size([896, 4864]) and whose dimensions in the checkpoint are torch.Size([896, 4864]), an exception occurred : ("LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.copy_', overload='default')>, types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>,), arg_types=(<class 'torch.nn.parameter.Parameter'>, <class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>), kwarg_types={}",).

I will check the torchao documentation and try to make it work. Do you have any suggestions?
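
For context, the failure above is in load_state_dict's default copy path: aten.copy_ is not implemented for torchao's tensor subclasses. A hypothetical workaround sketch, not necessarily what torchtune supports (assumes PyTorch >= 2.1 and the toy checkpoint from the illustration earlier in this thread):

import torch
from torchao.quantization import Int8DynActInt4WeightQuantizer

# Rebuild and quantize the model so its parameters are tensor subclasses,
# matching the structure of the saved checkpoint.
model = torch.nn.Sequential(torch.nn.Linear(256, 256))
model = Int8DynActInt4WeightQuantizer(groupsize=256).quantize(model)

# assign=True swaps the checkpoint tensors in directly instead of
# copy_-ing into the existing parameters, avoiding the unimplemented
# aten.copy_ dispatch.
state_dict = torch.load("/tmp/quantized.bin", weights_only=False)
model.load_state_dict(state_dict, assign=True)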

Contributor

@felipemello1 felipemello1 left a comment

Blocking merging (so this doesn't get merged by accident) until we test the model checkpoint.

@felipemello1
Contributor

felipemello1 commented Jan 29, 2025

@ebsmothers @joecummings what's the best way to test the quantized model for inference? Does it work with our FullModelHFCheckpointer?

@Ankur-singh
Contributor Author

Ankur-singh commented Feb 4, 2025

@felipemello1 Sorry for the delayed response.

Over the weekend, I did some digging and found that quantization is only supported by FullModelTorchTuneCheckpointer. This error message from the EleutherEval recipe sums it up pretty well. Pasting it below for reference.

ValueError: Quantization is only supported for models quantized and saved with the FullModelTorchTuneCheckpointer - please ensure you have quantized your model and are using the quantized weights!

So, I did the following to test the PR:

  1. First load the model using FullModelHFCheckpointer, then save it using FullModelTorchTuneCheckpointer.
  2. Quantize the model from step 1.
  3. Run the eval recipe with the quantized model from step 2.

Here is the code used for step 1:

from torchtune.models.qwen2_5 import qwen2_5_0_5b
from torchtune.training.checkpointing._checkpointer import (
    FullModelTorchTuneCheckpointer,
    FullModelHFCheckpointer,
)

hf_ckptr = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/Qwen2_5-0_5B-Instruct/",
    checkpoint_files=["model.safetensors",],
    model_type="QWEN2",
    output_dir="./",
)

ckpt_dict = hf_ckptr.load_checkpoint()  # in torch tune format

# Don't worry much about the arguments; we only use the `save_checkpoint`
# method, which only makes use of `checkpoint_dir` and `output_dir`.
tt_ckptr = FullModelTorchTuneCheckpointer(
    checkpoint_dir="/tmp/Qwen2_5-0_5B-Instruct/",
    checkpoint_files=["model.safetensors",],
    model_type="QWEN2",
    output_dir="./",
)

tt_ckptr.save_checkpoint(state_dict=ckpt_dict, epoch=0)

# ------------------------
# Model checkpoint of size 0.92 GiB saved to epoch_0/ft-model-00001-of-00001.bin
# Saving final epoch checkpoint.
# The full model checkpoint, including all weights and configurations, has been saved successfully. You can now use this checkpoint for further training or inference.

With the quantized model, I'm able to get a similar score to the unquantized model. Here is the output from both:

Unquantized Model:

(tune) ➜  torchtune git:(refactor/quantization-recipe-save-checkpoint) ✗ tune run eleuther_eval --config cust_eval.yaml
Running EleutherEvalRecipe with resolved config:

batch_size: 8
checkpointer:
  _component_: torchtune.training.FullModelTorchTuneCheckpointer
  checkpoint_dir: ./epoch_0
  checkpoint_files:
  - ft-model-00001-of-00001.bin
  model_type: QWEN2
  output_dir: ./
device: cuda
dtype: bf16
enable_kv_cache: true
limit: null
max_seq_length: 4096
model:
  _component_: torchtune.models.qwen2_5.qwen2_5_0_5b
output_dir: ./
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
  groupsize: 256
seed: 1234
tasks:
- truthfulqa_mc2
tokenizer:
  _component_: torchtune.models.qwen2_5.qwen2_5_tokenizer
  max_seq_len: null
  merges_file: ./epoch_0/merges.txt
  path: ./epoch_0/vocab.json


Model is initialized with precision torch.bfloat16.
2025-02-03:16:15:26,274 INFO     [huggingface.py:132] Using device 'cuda:0'
2025-02-03:16:15:27,230 INFO     [huggingface.py:369] Model parallel was set to False, max memory was not set, and device map was set to {'': 'cuda:0'}
Running evaluation on the following tasks: ['truthfulqa_mc2']
2025-02-03:16:15:36,446 INFO     [task.py:415] Building contexts for truthfulqa_mc2 on rank 0...
100%|██████████| 817/817 [00:00<00:00, 1652.22it/s]
2025-02-03:16:15:36,971 INFO     [evaluator.py:496] Running loglikelihood requests
Running loglikelihood requests: 100%|██████████| 5882/5882 [05:20<00:00, 18.37it/s]
Eval completed in 322.81 seconds.
Max memory allocated: 8.68 GB


|    Tasks     |Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|--------------|------:|------|-----:|------|---|-----:|---|-----:|
|truthfulqa_mc2|      2|none  |     0|acc   |↑  |0.4178|±  |0.0146|



Quantized Model:

(tune) ➜  torchtune git:(refactor/quantization-recipe-save-checkpoint) ✗ tune run eleuther_eval --config cust_eval.yaml
Running EleutherEvalRecipe with resolved config:

batch_size: 8
checkpointer:
  _component_: torchtune.training.FullModelTorchTuneCheckpointer
  checkpoint_dir: ./quantized/epoch_0
  checkpoint_files:
  - ft-model-00001-of-00001.bin
  model_type: QWEN2
  output_dir: ./
device: cuda
dtype: bf16
enable_kv_cache: true
limit: null
max_seq_length: 4096
model:
  _component_: torchtune.models.qwen2_5.qwen2_5_0_5b
output_dir: ./
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
  groupsize: 256
seed: 1234
tasks:
- truthfulqa_mc2
tokenizer:
  _component_: torchtune.models.qwen2_5.qwen2_5_tokenizer
  max_seq_len: null
  merges_file: ./quantized/epoch_0/merges.txt
  path: ./quantized/epoch_0/vocab.json


Model is initialized with precision torch.bfloat16.
2025-02-03:16:34:49,133 INFO     [huggingface.py:132] Using device 'cuda:0'
2025-02-03:16:34:59,780 INFO     [huggingface.py:369] Model parallel was set to False, max memory was not set, and device map was set to {'': 'cuda:0'}
Running evaluation on the following tasks: ['truthfulqa_mc2']
2025-02-03:16:35:08,922 INFO     [task.py:415] Building contexts for truthfulqa_mc2 on rank 0...
100%|██████████| 817/817 [00:00<00:00, 1606.56it/s]
2025-02-03:16:35:09,461 INFO     [evaluator.py:496] Running loglikelihood requests
Running loglikelihood requests: 100%|██████████| 5882/5882 [01:21<00:00, 72.27it/s]
Eval completed in 84.17 seconds.
Max memory allocated: 5.53 GB


|    Tasks     |Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|--------------|------:|------|-----:|------|---|-----:|---|-----:|
|truthfulqa_mc2|      2|none  |     0|acc   |↑  |0.4183|±  |0.0145|



You can see a clear difference in eval speed (322.81 sec vs 84.17 sec) and max memory allocated (8.68 GB vs 5.53 GB).

Also, here is the output from the generate recipe:

(tune) ➜  torchtune git:(refactor/quantization-recipe-save-checkpoint) ✗ tune run generate --config cust_gen.yaml
Running InferenceRecipe with resolved config:

checkpointer:
  _component_: torchtune.training.FullModelTorchTuneCheckpointer
  checkpoint_dir: ./quantized/epoch_0
  checkpoint_files:
  - ft-model-00001-of-00001.bin
  model_type: QWEN2
  output_dir: ./
device: cuda
dtype: bf16
enable_kv_cache: true
max_new_tokens: 300
model:
  _component_: torchtune.models.qwen2_5.qwen2_5_0_5b
output_dir: ./
prompt:
  system: null
  user: Tell me a joke.
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
  groupsize: 256
seed: 1234
temperature: 0.6
tokenizer:
  _component_: torchtune.models.qwen2_5.qwen2_5_tokenizer
  max_seq_len: null
  merges_file: ./quantized/epoch_0/merges.txt
  path: ./quantized/epoch_0/vocab.json
top_k: 300

Setting manual seed to local seed 1234. Local seed is seed + rank = 1234 + 0
Model is initialized with precision torch.bfloat16.
Starting compilation to improve generation performance ...
W0203 16:46:36.742000 748300 site-packages/torch/_inductor/utils.py:1048] [0/0] Not enough SMs to use max_autotune_gemm mode
AUTOTUNE addmm(1x896, 1x896, 896x896)
  bias_addmm 0.0113 ms 100.0% 
  addmm 0.0133 ms 84.6% 
SingleProcess AUTOTUNE benchmarking takes 0.4790 seconds and 0.0000 seconds precompiling
AUTOTUNE addmm(1x128, 1x896, 896x128)
  bias_addmm 0.0073 ms 100.0% 
  addmm 0.0092 ms 78.8% 
SingleProcess AUTOTUNE benchmarking takes 0.2306 seconds and 0.0000 seconds precompiling
AUTOTUNE addmm(1x896, 1x896, 896x896)
  bias_addmm 0.0102 ms 100.0% 
  addmm 0.0133 ms 76.9% 
SingleProcess AUTOTUNE benchmarking takes 0.2261 seconds and 0.0000 seconds precompiling
Warmup run for quantized model takes: 52.60 sec
<|im_start|>user
Tell me a joke.<|im_end|>
<|im_start|>assistant
<|endoftext|>Human: What do you call a person who likes broccoli?

A contestant in a game show.<|im_end|>
Time for inference: 22.84 sec total, 0.88 tokens/sec
Bandwidth achieved: 1.04 GB/s
Memory used: 1.99 GB
