Merge remote-tracking branch 'upstream/main'
dchourasia committed Sep 17, 2024
2 parents 84573fa + 229e230 commit 6871548
Showing 7 changed files with 328 additions and 18 deletions.
35 changes: 35 additions & 0 deletions .github/workflows/labelpr.yaml
@@ -0,0 +1,35 @@
name: Label PRs

on:
  pull_request_target:
    types: [opened, edited, synchronize, reopened]

jobs:
  label_pr:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/github-script@v3
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const pr_welcome_msg = `Thanks for making a pull request! 😃\nOne of the maintainers will review and advise on the next steps.`;
            // https://github.com/commitizen/conventional-commit-types
            const valid_pr_types = ['feat', 'fix', 'docs', 'style', 'refactor', 'perf', 'test', 'build', 'ci', 'chore', 'revert'];
            if(context.payload.pull_request.comments === 0) {
              await github.issues.createComment({ ...context.repo, issue_number: context.payload.number, body: pr_welcome_msg});
            }
            const title = context.payload.pull_request.title;
            const results = /^(\w+)(\(\w+\))?!?:/.exec(title);
            if (results === null) return core.setFailed(`The title does not follow conventional commits spec: https://www.conventionalcommits.org/en/v1.0.0/#summary Title: ${title}`);
            const pr_type = results[1];
            core.info(`pr_type: ${pr_type}`);
            if (!valid_pr_types.includes(pr_type)) return core.setFailed(`Unknown pull request type: ${pr_type}`);
            const labels = context.payload.pull_request.labels;
            const new_labels = labels.filter(label => !valid_pr_types.includes(label.name)); // keep all labels that are not in valid_pr_types
            new_labels.push({name: pr_type});
            await github.issues.update({ ...context.repo, issue_number: context.payload.number, labels: new_labels });
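Note: the script above posts a welcome comment on new pull requests and then validates the PR title against the conventional-commits pattern `^(\w+)(\(\w+\))?!?:`. As a rough illustration of how that pattern classifies titles — a sketch in Python rather than the workflow's JavaScript, with a hypothetical helper name and example titles that are not part of this commit:

import re

# Same pattern as the workflow: a type word, an optional "(scope)", an optional "!", then ":".
CONVENTIONAL_TITLE = re.compile(r"^(\w+)(\(\w+\))?!?:")
VALID_PR_TYPES = {
    "feat", "fix", "docs", "style", "refactor", "perf",
    "test", "build", "ci", "chore", "revert",
}


def classify_pr_title(title: str) -> str:
    """Return the PR type for a valid conventional-commits title, else raise ValueError."""
    match = CONVENTIONAL_TITLE.match(title)
    if match is None:
        raise ValueError(f"Title does not follow conventional commits: {title}")
    pr_type = match.group(1)
    if pr_type not in VALID_PR_TYPES:
        raise ValueError(f"Unknown pull request type: {pr_type}")
    return pr_type


# Hypothetical example titles:
print(classify_pr_title("feat(tuning): add GPTQ support"))   # -> feat
print(classify_pr_title("fix!: handle empty checkpoints"))   # -> fix

The optional `(\w+)` group accepts a scope such as `feat(tuning):`, and the optional `!` marks a breaking change, in line with the conventional-commits summary format the workflow links to.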
2 changes: 1 addition & 1 deletion .pylintrc
@@ -333,7 +333,7 @@ indent-string=' '
max-line-length=100

# Maximum number of lines in a module.
max-module-lines=1100
max-module-lines=1200

# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
5 changes: 5 additions & 0 deletions README.md
@@ -278,6 +278,11 @@ You can set `output_dir` to a local directory and set `save_model_dir` to COS to

In order to achieve the fastest training time, set `save_strategy="no"`; saving no checkpoints except for the final model removes intermediate write operations altogether.

#### Resuming tuning from checkpoints
If the output directory already contains checkpoints, tuning will automatically resume from the latest checkpoint in the directory specified by the `output_dir` flag. To start tuning from scratch and ignore existing checkpoints, set the `resume_from_checkpoint` flag to False.

You can also use the `resume_from_checkpoint` flag to resume tuning from a specific checkpoint by providing the full path to the desired checkpoint as a string. This flag is passed as an argument to the [trainer.train()](https://github.com/huggingface/transformers/blob/db70426854fe7850f2c5834d633aff637f14772e/src/transformers/trainer.py#L1901) function of the SFTTrainer.
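As a rough illustration, the new tests in this commit drive this behaviour through `sft_trainer.train`. The following is a minimal sketch of that usage, not a definitive recipe: `TRAIN_ARGS`, `MODEL_ARGS`, and `DATA_ARGS` stand in for the usual training, model, and data configs (they are fixtures in `tests/test_sft_trainer.py`), and the paths are hypothetical.

import copy

from tuning import sft_trainer

train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = "/tmp/tuning-output"  # hypothetical output directory

# Default behaviour: if output_dir already holds checkpoints, tuning resumes from the latest one.
# To start from scratch instead:
train_args.resume_from_checkpoint = "False"

# ...or resume from one specific checkpoint by passing its full path as a string:
# train_args.resume_from_checkpoint = "/tmp/tuning-output/checkpoint-100"

sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)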

## Tuning Techniques:

### LoRA Tuning Example
1 change: 1 addition & 0 deletions pyproject.toml
@@ -45,6 +45,7 @@ dev = ["wheel>=0.42.0,<1.0", "packaging>=23.2,<25", "ninja>=1.11.1.1,<2.0", "sci
flash-attn = ["flash-attn>=2.5.3,<3.0"]
aim = ["aim>=3.19.0,<4.0"]
fms-accel = ["fms-acceleration>=0.1"]
gptq-dev = ["auto_gptq>0.4.2", "optimum>=1.15.0"]


[tool.setuptools.packages.find]
73 changes: 58 additions & 15 deletions scripts/run_inference.py
@@ -30,7 +30,7 @@
# Third Party
from peft import PeftModel
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
import torch

# Local
@@ -176,20 +176,45 @@ def load(
            else {}
        )
        tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
        device = "cuda" if torch.cuda.is_available() else None
        print(f"Inferred device: {device}")
        # Apply the configs to the adapter config of this model; if no overrides
        # are provided, then the context manager doesn't have any effect.
        try:
            with AdapterConfigPatcher(checkpoint_path, overrides):
                try:
                    if base_model_name_or_path is None:
                        raise ValueError("base_model_name_or_path has to be passed")
                    base_model = AutoModelForCausalLM.from_pretrained(
                        base_model_name_or_path,
                        attn_implementation="flash_attention_2"
                        if use_flash_attn
                        else None,
                        torch_dtype=torch.bfloat16 if use_flash_attn else None,
                    )

                    if (
                        has_quantized_config(base_model_name_or_path)
                        and device == "cuda"
                    ):
                        # Using GPTQConfig from HF, avail params are here
                        # https://huggingface.co/docs/transformers/en/main_classes/quantization#transformers.GPTQConfig
                        # We only support 4-bit AutoGPTQ, so setting bits to 4
                        # setting exllama kernel to version 2 as it's a faster kernel
                        gptq_config = GPTQConfig(bits=4, exllama_config={"version": 2})

                        # Since we are using exllama kernel, we need torch.float16 as torch_dtype
                        base_model = AutoModelForCausalLM.from_pretrained(
                            base_model_name_or_path,
                            attn_implementation="flash_attention_2"
                            if use_flash_attn
                            else None,
                            device_map=device,
                            torch_dtype=torch.float16,
                            quantization_config=gptq_config,
                        )
                    else:
                        base_model = AutoModelForCausalLM.from_pretrained(
                            base_model_name_or_path,
                            attn_implementation="flash_attention_2"
                            if use_flash_attn
                            else None,
                            torch_dtype=torch.bfloat16 if use_flash_attn else None,
                        )

                    # since the peft library (PEFTModelForCausalLM) does not handle cases
                    # where the model's layers are modified, in our case the embedding layer
                    # is modified, so we resize the backbone model's embedding layer with our own
@@ -211,14 +236,28 @@ def load(
        except FileNotFoundError:
            print("No adapter config found! Loading as a merged model...")
            # Unable to find the adapter config; fall back to loading as a merged model
            model = AutoModelForCausalLM.from_pretrained(
                checkpoint_path,
                attn_implementation="flash_attention_2" if use_flash_attn else None,
                torch_dtype=torch.bfloat16 if use_flash_attn else None,
            )
            if has_quantized_config(checkpoint_path) and device == "cuda":
                # Using GPTQConfig from HF, avail params are here
                # https://huggingface.co/docs/transformers/en/main_classes/quantization#transformers.GPTQConfig
                # We only support 4-bit AutoGPTQ, so setting bits to 4
                # setting exllama kernel to version 2 as it's a faster kernel
                gptq_config = GPTQConfig(bits=4, exllama_config={"version": 2})

                # Since we are using exllama kernel, we need torch.float16 as torch_dtype
                model = AutoModelForCausalLM.from_pretrained(
                    checkpoint_path,
                    attn_implementation="flash_attention_2" if use_flash_attn else None,
                    device_map=device,
                    torch_dtype=torch.float16,
                    quantization_config=gptq_config,
                )
            else:
                model = AutoModelForCausalLM.from_pretrained(
                    checkpoint_path,
                    attn_implementation="flash_attention_2" if use_flash_attn else None,
                    torch_dtype=torch.bfloat16 if use_flash_attn else None,
                )

        device = "cuda" if torch.cuda.is_available() else None
        print(f"Inferred device: {device}")
        model.to(device)
        return cls(model, tokenizer, device)

@@ -327,5 +366,9 @@ def main():
    print(f"Exported results to: {args.out_file}")


def has_quantized_config(model_path: str):
    return os.path.exists(os.path.join(model_path, "quantize_config.json"))


if __name__ == "__main__":
    main()
208 changes: 208 additions & 0 deletions tests/test_sft_trainer.py
@@ -80,6 +80,214 @@
PEFT_LORA_ARGS = peft_config.LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05)


def test_resume_training_from_checkpoint():
    """
    Test that tuning resumes from the latest checkpoint, creating new checkpoints,
    and that the checkpoints created before resuming tuning are not affected.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        train_args = copy.deepcopy(TRAIN_ARGS)
        train_args.output_dir = tempdir

        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
        _validate_training(tempdir)

        # Get trainer state of latest checkpoint
        init_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
        assert init_trainer_state is not None

        # Resume training with higher epoch and same output dir
        train_args.num_train_epochs += 5
        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
        _validate_training(tempdir)

        # Get trainer state of latest checkpoint
        final_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
        assert final_trainer_state is not None

        assert final_trainer_state["epoch"] == init_trainer_state["epoch"] + 5
        assert final_trainer_state["global_step"] > init_trainer_state["global_step"]

        # Check that the loss of the 1st epoch from the first tuning run is the same
        # after resuming tuning and is not overwritten
        assert len(init_trainer_state["log_history"]) > 0

        init_log_history = init_trainer_state["log_history"][0]
        assert init_log_history["epoch"] == 1

        final_log_history = final_trainer_state["log_history"][0]
        assert final_log_history["epoch"] == 1

        assert init_log_history["loss"] == final_log_history["loss"]


def test_resume_training_from_checkpoint_with_flag_true():
    """
    Test that tuning resumes from the latest checkpoint when the flag is true,
    creating new checkpoints, and that the checkpoints created before resuming
    tuning are not affected.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        train_args = copy.deepcopy(TRAIN_ARGS)
        train_args.output_dir = tempdir
        train_args.resume_from_checkpoint = "True"

        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
        _validate_training(tempdir)

        # Get trainer state of latest checkpoint
        init_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
        assert init_trainer_state is not None

        # Get training logs
        init_training_logs = _get_training_logs_by_epoch(tempdir)

        # Resume training with higher epoch and same output dir
        train_args.num_train_epochs += 5
        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
        _validate_training(tempdir)

        # Get trainer state of latest checkpoint
        final_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
        assert final_trainer_state is not None

        assert final_trainer_state["epoch"] == init_trainer_state["epoch"] + 5
        assert final_trainer_state["global_step"] > init_trainer_state["global_step"]

        final_training_logs = _get_training_logs_by_epoch(tempdir)

        assert (
            init_training_logs[0]["data"]["timestamp"]
            == final_training_logs[0]["data"]["timestamp"]
        )


def test_resume_training_from_checkpoint_with_flag_false():
    """
    Test that tuning starts from scratch when resume_from_checkpoint is set to False.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        train_args = copy.deepcopy(TRAIN_ARGS)
        train_args.output_dir = tempdir
        train_args.resume_from_checkpoint = "False"

        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
        _validate_training(tempdir)

        # Get trainer state of latest checkpoint
        init_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
        assert init_trainer_state is not None

        # Get training log entry for epoch 1
        init_training_logs = _get_training_logs_by_epoch(tempdir, epoch=1)
        assert len(init_training_logs) == 1

        # Training again with higher epoch and same output dir
        train_args.num_train_epochs += 5
        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
        _validate_training(tempdir)

        # Get training log entry for epoch 1
        final_training_logs = _get_training_logs_by_epoch(tempdir, epoch=1)
        assert len(final_training_logs) == 2


def test_resume_training_from_checkpoint_with_flag_checkpoint_path_lora():
    """
    Test resuming tuning from a specified checkpoint path for LoRA tuning.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        train_args = copy.deepcopy(TRAIN_ARGS)
        lora_config = copy.deepcopy(PEFT_LORA_ARGS)
        train_args.output_dir = tempdir

        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, lora_config)
        _validate_training(tempdir)

        # Get trainer state and checkpoint_path of second last checkpoint
        init_trainer_state, checkpoint_path = _get_latest_checkpoint_trainer_state(
            tempdir, checkpoint_index=-2
        )
        assert init_trainer_state is not None

        # Resume training with higher epoch and same output dir
        train_args.num_train_epochs += 5
        train_args.resume_from_checkpoint = checkpoint_path
        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, lora_config)
        _validate_training(tempdir)

        # Get total_flos from the trainer state of checkpoint_path and check that it is unchanged
        final_trainer_state = None
        trainer_state_file = os.path.join(checkpoint_path, "trainer_state.json")
        with open(trainer_state_file, "r", encoding="utf-8") as f:
            final_trainer_state = json.load(f)

        assert final_trainer_state["total_flos"] == init_trainer_state["total_flos"]


def _get_latest_checkpoint_trainer_state(dir_path: str, checkpoint_index: int = -1):
    """
    Get the trainer state from the latest or specified checkpoint directory.
    The trainer state is returned along with the path to the checkpoint.
    Args:
        dir_path (str): The directory path where checkpoint folders are located.
        checkpoint_index (int, optional): The index of the checkpoint to retrieve,
            based on the checkpoint number. The default is -1,
            which returns the latest checkpoint.
    Returns:
        trainer_state: The trainer state loaded from `trainer_state.json` in the
            checkpoint directory.
        last_checkpoint: The path to the checkpoint directory.
    """
    trainer_state = None
    last_checkpoint = None
    checkpoints = [
        os.path.join(dir_path, d)
        for d in os.listdir(dir_path)
        if d.startswith("checkpoint")
    ]
    if checkpoints:
        last_checkpoint = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]))[
            checkpoint_index
        ]
        trainer_state_file = os.path.join(last_checkpoint, "trainer_state.json")
        with open(trainer_state_file, "r", encoding="utf-8") as f:
            trainer_state = json.load(f)
    return trainer_state, last_checkpoint


def _get_training_logs_by_epoch(dir_path: str, epoch: int = None):
    """
    Load and optionally filter the training_logs.jsonl file.
    If an epoch number is specified, the function filters the logs
    and returns only the entries corresponding to that epoch.
    Args:
        dir_path (str): The directory path where the `training_logs.jsonl` file is located.
        epoch (int, optional): The epoch number to filter logs by. If not specified,
            all logs are returned.
    Returns:
        list: A list containing the training logs. If `epoch` is specified,
            only logs from that epoch are returned; otherwise, all logs are returned.
    """
    data_list = []
    with open(f"{dir_path}/training_logs.jsonl", "r", encoding="utf-8") as file:
        for line in file:
            json_data = json.loads(line)
            data_list.append(json_data)

    if epoch:
        mod_data_list = []
        for value in data_list:
            if value["data"]["epoch"] == epoch:
                mod_data_list.append(value)
        return mod_data_list
    return data_list


def test_run_train_requires_output_dir():
    """Check fails when output dir not provided."""
    updated_output_dir_train_args = copy.deepcopy(TRAIN_ARGS)