Update QuantizationRecipe to use checkpointer.save_checkpoint #2257
base: main
Changes from 1 commit
```diff
@@ -3,10 +3,8 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-import os
 import sys
 import time
-from pathlib import Path
 from typing import Any, Dict

 import torch
@@ -53,6 +51,11 @@ def __init__(self, cfg: DictConfig) -> None:
         training.set_seed(seed=cfg.seed)

     def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]:
+        logger.info(
+            "Setting safe_serialization to False. TorchAO quantization is compatible "
+            "only with HuggingFace's non-safetensor serialization and deserialization."
+        )
+        checkpointer_cfg.safe_serialization = False
         self._checkpointer = config.instantiate(checkpointer_cfg)
         checkpoint_dict = self._checkpointer.load_checkpoint()
         return checkpoint_dict
@@ -95,21 +98,8 @@ def quantize(self, cfg: DictConfig):
         logger.info(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

     def save_checkpoint(self, cfg: DictConfig):
-        ckpt_dict = self._model.state_dict()
-        file_name = cfg.checkpointer.checkpoint_files[0].split(".")[0]
-
-        output_dir = Path(cfg.checkpointer.output_dir)
-        output_dir.mkdir(exist_ok=True)
-        checkpoint_file = Path.joinpath(
-            output_dir, f"{file_name}-{self._quantization_mode}".rstrip("-qat")
-        ).with_suffix(".pt")
-
-        torch.save(ckpt_dict, checkpoint_file)
-        logger.info(
-            "Model checkpoint of size "
-            f"{os.path.getsize(checkpoint_file) / 1024**3:.2f} GiB "
-            f"saved to {checkpoint_file}"
-        )
+        ckpt_dict = {training.MODEL_KEY: self._model.state_dict()}
+        self._checkpointer.save_checkpoint(ckpt_dict, epoch=0)


 @config.parse
```
Review comment (on `save_checkpoint`):
This does clean the code quite a bit, but now the files are saved with an awkward name in an …

Reply:
Totally agreed, the file name is a bit awkward. However, since …
Review comment (on `safe_serialization`):
`safe_serialization` is only an argument for the HF checkpointer and not the others, so the instantiation would fail for other checkpointer classes. Maybe you could check for the attribute after the checkpointer is instantiated, and then set it to False if it's present?
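A minimal sketch of what the reviewer suggests, assuming the HF checkpointer exposes `safe_serialization` as a plain instance attribute that can be flipped after construction (the attribute name on the instance and its post-init mutability are assumptions, not confirmed API):

```python
def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]:
    # Instantiate first: non-HF checkpointers don't accept a
    # safe_serialization kwarg, so injecting it into the config
    # before instantiation would raise for those classes.
    self._checkpointer = config.instantiate(checkpointer_cfg)

    # Only flip the flag when the instantiated checkpointer has it
    # (assumed here to be the HF checkpointer only).
    if hasattr(self._checkpointer, "safe_serialization"):
        logger.info(
            "Setting safe_serialization to False. TorchAO quantization is "
            "compatible only with HuggingFace's non-safetensor serialization "
            "and deserialization."
        )
        self._checkpointer.safe_serialization = False

    return self._checkpointer.load_checkpoint()
```

Checking the attribute on the instance rather than mutating the config keeps the recipe agnostic to which checkpointer class the config names.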
Reply:
That is a very good point. I will make the required changes.