From 0df8e7a20255e556a2c82c4be7992b6e744a5127 Mon Sep 17 00:00:00 2001 From: Sukriti Sharma Date: Thu, 23 May 2024 09:55:51 -0600 Subject: [PATCH 1/2] Refactor tests explicit params (#163) * refactor tests for explicit param passing Signed-off-by: Sukriti-Sharma4 * refacor more tests Signed-off-by: Sukriti-Sharma4 * merge main and cleanup Signed-off-by: Sukriti-Sharma4 * remove test helpers Signed-off-by: Sukriti-Sharma4 * fix formatting Signed-off-by: Sukriti-Sharma4 * update eval_strategy flag Signed-off-by: Anh-Uong * use empty data in test Signed-off-by: Anh-Uong --------- Signed-off-by: Sukriti-Sharma4 Signed-off-by: Anh-Uong Co-authored-by: Anh-Uong --- tests/helpers.py | 45 ----- tests/test_sft_trainer.py | 342 ++++++++++++++------------------------ 2 files changed, 129 insertions(+), 258 deletions(-) delete mode 100644 tests/helpers.py diff --git a/tests/helpers.py b/tests/helpers.py deleted file mode 100644 index 59695826f..000000000 --- a/tests/helpers.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright The FMS HF Tuning Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Third Party -import transformers - -# Local -from tuning.config import configs, peft_config - - -def causal_lm_train_kwargs(train_kwargs): - """Parse the kwargs for a valid train call to a Causal LM.""" - parser = transformers.HfArgumentParser( - dataclass_types=( - configs.ModelArguments, - configs.DataArguments, - configs.TrainingArguments, - peft_config.LoraConfig, - peft_config.PromptTuningConfig, - ) - ) - ( - model_args, - data_args, - training_args, - lora_config, - prompt_tuning_config, - ) = parser.parse_dict(train_kwargs, allow_extra_keys=True) - tuning_config = None - if train_kwargs.get("peft_method") == "lora": - tuning_config = lora_config - elif train_kwargs.get("peft_method") == "pt": - tuning_config = prompt_tuning_config - return (model_args, data_args, training_args, tuning_config) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index a55f7d25b..bbe91f890 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -30,103 +30,60 @@ # First Party from scripts.run_inference import TunedCausalLM from tests.data import EMPTY_DATA, MALFORMATTED_DATA, TWITTER_COMPLAINTS_DATA -from tests.helpers import causal_lm_train_kwargs # Local from tuning import sft_trainer -from tuning.config import peft_config +from tuning.config import configs, peft_config MODEL_NAME = "Maykeye/TinyLLama-v0" -BASE_PEFT_KWARGS = { - "model_name_or_path": MODEL_NAME, - "training_data_path": TWITTER_COMPLAINTS_DATA, - "num_train_epochs": 5, - "per_device_train_batch_size": 4, - "per_device_eval_batch_size": 4, - "gradient_accumulation_steps": 4, - "learning_rate": 0.00001, - "weight_decay": 0, - "warmup_ratio": 0.03, - "lr_scheduler_type": "cosine", - "logging_steps": 1, - "include_tokens_per_second": True, - "packing": False, - "response_template": "\n### Label:", - "dataset_text_field": "output", - "use_flash_attn": False, - "torch_dtype": "float32", - "max_seq_length": 4096, - "peft_method": "pt", - "prompt_tuning_init": "RANDOM", - "num_virtual_tokens": 8, - "prompt_tuning_init_text": "hello", - "tokenizer_name_or_path": MODEL_NAME, - "save_strategy": "epoch", - "output_dir": "tmp", -} - -BASE_LORA_KWARGS = copy.deepcopy(BASE_PEFT_KWARGS) -BASE_LORA_KWARGS["peft_method"] = "lora" - -BASE_FT_KWARGS = copy.deepcopy(BASE_PEFT_KWARGS) -BASE_FT_KWARGS["peft_method"] = None -del BASE_FT_KWARGS["prompt_tuning_init"] -del BASE_FT_KWARGS["prompt_tuning_init_text"] - - -def test_helper_causal_lm_train_kwargs(): - """Check happy path kwargs passed and parsed properly.""" - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - BASE_PEFT_KWARGS - ) - - assert model_args.model_name_or_path == MODEL_NAME - assert model_args.use_flash_attn is False - assert model_args.torch_dtype == "float32" - - assert data_args.training_data_path == TWITTER_COMPLAINTS_DATA - assert data_args.response_template == "\n### Label:" - assert data_args.dataset_text_field == "output" - - assert training_args.num_train_epochs == 5 - assert training_args.max_seq_length == 4096 - assert training_args.save_strategy == "epoch" - - assert tune_config.prompt_tuning_init == "RANDOM" - assert tune_config.prompt_tuning_init_text == "hello" - assert tune_config.tokenizer_name_or_path == MODEL_NAME - assert tune_config.num_virtual_tokens == 8 +MODEL_ARGS = configs.ModelArguments( + model_name_or_path=MODEL_NAME, use_flash_attn=False, torch_dtype="float32" +) +DATA_ARGS = configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_DATA, + response_template="\n### Label:", + dataset_text_field="output", +) +TRAIN_ARGS = configs.TrainingArguments( + num_train_epochs=5, + per_device_train_batch_size=4, + per_device_eval_batch_size=4, + gradient_accumulation_steps=4, + learning_rate=0.00001, + weight_decay=0, + warmup_ratio=0.03, + lr_scheduler_type="cosine", + logging_steps=1, + include_tokens_per_second=True, + packing=False, + max_seq_length=4096, + save_strategy="epoch", + output_dir="tmp", +) +PEFT_PT_ARGS = peft_config.PromptTuningConfig( + prompt_tuning_init="RANDOM", + num_virtual_tokens=8, + prompt_tuning_init_text="hello", + tokenizer_name_or_path=MODEL_NAME, +) - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - BASE_FT_KWARGS - ) - assert tune_config is None - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - BASE_LORA_KWARGS - ) - assert isinstance(tune_config, peft_config.LoraConfig) +PEFT_LORA_ARGS = peft_config.LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05) def test_run_train_requires_output_dir(): """Check fails when output dir not provided.""" - updated_output_dir = copy.deepcopy(BASE_PEFT_KWARGS) - updated_output_dir["output_dir"] = None - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - updated_output_dir - ) + updated_output_dir_train_args = copy.deepcopy(TRAIN_ARGS) + updated_output_dir_train_args.output_dir = None with pytest.raises(TypeError): - sft_trainer.train(model_args, data_args, training_args, tune_config) + sft_trainer.train(MODEL_ARGS, DATA_ARGS, updated_output_dir_train_args, None) def test_run_train_fails_training_data_path_not_exist(): """Check fails when data path not found.""" - updated_output_path = copy.deepcopy(BASE_PEFT_KWARGS) - updated_output_path["training_data_path"] = "fake/path" - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - updated_output_path - ) + updated_data_path_args = copy.deepcopy(DATA_ARGS) + updated_data_path_args.training_data_path = "fake/path" with pytest.raises(FileNotFoundError): - sft_trainer.train(model_args, data_args, training_args, tune_config) + sft_trainer.train(MODEL_ARGS, updated_data_path_args, TRAIN_ARGS, None) ############################# Prompt Tuning Tests ############################# @@ -135,18 +92,16 @@ def test_run_train_fails_training_data_path_not_exist(): def test_run_causallm_pt_and_inference(): """Check if we can bootstrap and peft tune causallm models""" with tempfile.TemporaryDirectory() as tempdir: - TRAIN_KWARGS = {**BASE_PEFT_KWARGS, **{"output_dir": tempdir}} + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - TRAIN_KWARGS - ) - sft_trainer.train(model_args, data_args, training_args, tune_config) + sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, PEFT_PT_ARGS) # validate peft tuning configs _validate_training(tempdir) checkpoint_path = _get_checkpoint_path(tempdir) adapter_config = _get_adapter_config(checkpoint_path) - _validate_adapter_config(adapter_config, "PROMPT_TUNING", BASE_PEFT_KWARGS) + _validate_adapter_config(adapter_config, "PROMPT_TUNING", PEFT_PT_ARGS) # Load the model loaded_model = TunedCausalLM.load(checkpoint_path) @@ -162,20 +117,22 @@ def test_run_causallm_pt_and_inference(): def test_run_causallm_pt_init_text(): """Check if we can bootstrap and peft tune causallm models with init text as 'TEXT'""" with tempfile.TemporaryDirectory() as tempdir: - TRAIN_KWARGS = { - **BASE_PEFT_KWARGS, - **{"output_dir": tempdir, "prompt_tuning_init": "TEXT"}, - } - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - TRAIN_KWARGS + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + + tuning_config = peft_config.PromptTuningConfig( + prompt_tuning_init="TEXT", + prompt_tuning_init_text="hello", + tokenizer_name_or_path=MODEL_NAME, ) - sft_trainer.train(model_args, data_args, training_args, tune_config) + + sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, tuning_config) # validate peft tuning configs _validate_training(tempdir) checkpoint_path = _get_checkpoint_path(tempdir) adapter_config = _get_adapter_config(checkpoint_path) - _validate_adapter_config(adapter_config, "PROMPT_TUNING", TRAIN_KWARGS) + _validate_adapter_config(adapter_config, "PROMPT_TUNING", tuning_config) invalid_params_map = [ @@ -193,34 +150,27 @@ def test_run_causallm_pt_init_text(): invalid_params_map, ids=["num_train_epochs", "grad_acc_steps"], ) -def test_run_causallm_pt_invalid_params(param_name, param_val, exc_msg): +def test_run_causallm_pt_invalid_train_params(param_name, param_val, exc_msg): """Check if error is raised when invalid params are used to peft tune causallm models""" with tempfile.TemporaryDirectory() as tempdir: - invalid_params = copy.deepcopy(BASE_PEFT_KWARGS) - invalid_params["output_dir"] = tempdir - invalid_params[param_name] = param_val - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - invalid_params - ) + invalid_params = copy.deepcopy(TRAIN_ARGS) + invalid_params.output_dir = tempdir + setattr(invalid_params, param_name, param_val) with pytest.raises(ValueError, match=exc_msg): - sft_trainer.train(model_args, data_args, training_args, tune_config) + sft_trainer.train(MODEL_ARGS, DATA_ARGS, invalid_params, PEFT_PT_ARGS) def test_run_causallm_pt_with_validation(): """Check if we can bootstrap and peft tune causallm models with validation dataset""" with tempfile.TemporaryDirectory() as tempdir: - validation_peft = copy.deepcopy(BASE_PEFT_KWARGS) - validation_peft["output_dir"] = tempdir - validation_peft["validation_data_path"] = TWITTER_COMPLAINTS_DATA - validation_peft["evaluation_strategy"] = "epoch" - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - validation_peft - ) - - assert data_args.validation_data_path == TWITTER_COMPLAINTS_DATA + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + train_args.eval_strategy = "epoch" + data_args = copy.deepcopy(DATA_ARGS) + data_args.validation_data_path = TWITTER_COMPLAINTS_DATA - sft_trainer.train(model_args, data_args, training_args, tune_config) + sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS) _validate_training(tempdir, check_eval=True) @@ -247,21 +197,19 @@ def test_run_causallm_pt_with_validation(): def test_run_causallm_lora_and_inference(request, target_modules, expected): """Check if we can bootstrap and lora tune causallm models""" with tempfile.TemporaryDirectory() as tempdir: - base_lora_kwargs = copy.deepcopy(BASE_LORA_KWARGS) - base_lora_kwargs["output_dir"] = tempdir + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + base_lora_args = copy.deepcopy(PEFT_LORA_ARGS) if "default" not in request._pyfuncitem.callspec.id: - base_lora_kwargs["target_modules"] = target_modules + base_lora_args.target_modules = target_modules - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - base_lora_kwargs - ) - sft_trainer.train(model_args, data_args, training_args, tune_config) + sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, base_lora_args) # validate lora tuning configs _validate_training(tempdir) checkpoint_path = _get_checkpoint_path(tempdir) adapter_config = _get_adapter_config(checkpoint_path) - _validate_adapter_config(adapter_config, "LORA", base_lora_kwargs) + _validate_adapter_config(adapter_config, "LORA", base_lora_args) for module in expected: assert module in adapter_config.get("target_modules") @@ -283,14 +231,10 @@ def test_run_causallm_lora_and_inference(request, target_modules, expected): def test_run_causallm_ft_and_inference(): """Check if we can bootstrap and finetune tune causallm models""" with tempfile.TemporaryDirectory() as tempdir: - BASE_FT_KWARGS["output_dir"] = tempdir - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - BASE_FT_KWARGS - ) - # Just assuring no tuning config is passed for PT or LoRA - assert tune_config is None + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir - sft_trainer.train(model_args, data_args, training_args, tune_config) + sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None) # validate ft tuning configs _validate_training(tempdir) @@ -332,13 +276,13 @@ def _get_adapter_config(dir_path): return json.load(f) -def _validate_adapter_config(adapter_config, peft_type, base_kwargs): +def _validate_adapter_config(adapter_config, peft_type, tuning_config): assert adapter_config.get("task_type") == "CAUSAL_LM" assert adapter_config.get("peft_type") == peft_type assert ( ( adapter_config.get("tokenizer_name_or_path") - == base_kwargs["tokenizer_name_or_path"] + == tuning_config.tokenizer_name_or_path ) if peft_type == "PROMPT_TUNING" else True @@ -358,118 +302,96 @@ def test_tokenizer_has_no_eos_token(): # This is a bit roundabout, but patch the tokenizer and export it and the model to a tempdir # that we can then reload out of for the train call, and clean up afterwards. tokenizer = transformers.AutoTokenizer.from_pretrained( - BASE_PEFT_KWARGS["model_name_or_path"] + MODEL_ARGS.model_name_or_path ) model = transformers.AutoModelForCausalLM.from_pretrained( - BASE_PEFT_KWARGS["model_name_or_path"] + MODEL_ARGS.model_name_or_path ) tokenizer.eos_token = None with tempfile.TemporaryDirectory() as tempdir: tokenizer.save_pretrained(tempdir) model.save_pretrained(tempdir) - TRAIN_KWARGS = { - **BASE_PEFT_KWARGS, - **{"model_name_or_path": tempdir, "output_dir": tempdir}, - } - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - TRAIN_KWARGS - ) + + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + model_args = copy.deepcopy(MODEL_ARGS) + train_args.model_name_or_path = tempdir + # If we handled this badly, we would probably get something like a # TypeError: can only concatenate str (not "NoneType") to str error # when we go to apply the data formatter. - sft_trainer.train(model_args, data_args, training_args, tune_config) + sft_trainer.train(model_args, DATA_ARGS, train_args, PEFT_PT_ARGS) _validate_training(tempdir) ### Tests for Bad dataset specification, i.e., data is valid, but the field we point it at isn't def test_invalid_dataset_text_field(): """Ensure that if we specify a dataset_text_field that doesn't exist, we get a KeyError.""" - TRAIN_KWARGS = { - **BASE_PEFT_KWARGS, - **{"dataset_text_field": "not found", "output_dir": "foo/bar/baz"}, - } - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - TRAIN_KWARGS - ) + + data_args = copy.deepcopy(DATA_ARGS) + data_args.dataset_text_field = "not found" + with pytest.raises(KeyError): - sft_trainer.train(model_args, data_args, training_args, tune_config) + sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS) ### Tests for bad training data (i.e., data_path is an unhappy value or points to an unhappy thing) def test_malformatted_data(): """Ensure that malformatted data explodes due to failure to generate the dataset.""" - TRAIN_KWARGS = { - **BASE_PEFT_KWARGS, - **{"training_data_path": MALFORMATTED_DATA, "output_dir": "foo/bar/baz"}, - } - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - TRAIN_KWARGS - ) + data_args = copy.deepcopy(DATA_ARGS) + data_args.training_data_path = MALFORMATTED_DATA + with pytest.raises(DatasetGenerationError): - sft_trainer.train(model_args, data_args, training_args, tune_config) + sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS) def test_empty_data(): """Ensure that malformatted data explodes due to failure to generate the dataset.""" - TRAIN_KWARGS = { - **BASE_PEFT_KWARGS, - **{"training_data_path": EMPTY_DATA, "output_dir": "foo/bar/baz"}, - } - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - TRAIN_KWARGS - ) + data_args = copy.deepcopy(DATA_ARGS) + data_args.training_data_path = EMPTY_DATA + with pytest.raises(DatasetGenerationError): - sft_trainer.train(model_args, data_args, training_args, tune_config) + sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS) def test_data_path_is_a_directory(): """Ensure that we get FileNotFoundError if we point the data path at a dir, not a file.""" with tempfile.TemporaryDirectory() as tempdir: - TRAIN_KWARGS = { - **BASE_PEFT_KWARGS, - **{"training_data_path": tempdir, "output_dir": tempdir}, - } - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - TRAIN_KWARGS - ) + data_args = copy.deepcopy(DATA_ARGS) + data_args.training_data_path = tempdir + # Confusingly, if we pass a directory for our data path, it will throw a # FileNotFoundError saying "unable to find ''", since it can't # find a matchable file in the path. with pytest.raises(FileNotFoundError): - sft_trainer.train(model_args, data_args, training_args, tune_config) + sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS) ### Tests for bad tuning module configurations def test_run_causallm_lora_with_invalid_modules(): """Check that we throw a value error if the target modules for lora don't exist.""" with tempfile.TemporaryDirectory() as tempdir: - TRAIN_KWARGS = { - **BASE_PEFT_KWARGS, - **{"peft_method": "lora", "output_dir": tempdir}, - } - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - TRAIN_KWARGS - ) + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir # Defaults are q_proj / v_proj; this will fail lora as the torch module doesn't have them - tune_config.target_modules = ["foo", "bar"] + lora_config = copy.deepcopy(PEFT_LORA_ARGS) + lora_config.target_modules = ["foo", "bar"] # Peft should throw a value error about modules not matching the base module with pytest.raises(ValueError): - sft_trainer.train(model_args, data_args, training_args, tune_config) + sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, lora_config) ### Direct validation tests based on whether or not packing is enabled def test_no_packing_needs_dataset_text_field(): """Ensure we need to set the dataset text field if packing is False""" with tempfile.TemporaryDirectory() as tempdir: - TRAIN_KWARGS = { - **BASE_PEFT_KWARGS, - **{"dataset_text_field": None, "output_dir": tempdir}, - } - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - TRAIN_KWARGS - ) + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + data_args = copy.deepcopy(DATA_ARGS) + data_args.dataset_text_field = None + with pytest.raises(ValueError): - sft_trainer.train(model_args, data_args, training_args, tune_config) + sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS) # TODO: Fix this case @@ -477,15 +399,13 @@ def test_no_packing_needs_dataset_text_field(): def test_no_packing_needs_reponse_template(): """Ensure we need to set the response template if packing is False""" with tempfile.TemporaryDirectory() as tempdir: - TRAIN_KWARGS = { - **BASE_PEFT_KWARGS, - **{"response_template": None, "output_dir": tempdir}, - } - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - TRAIN_KWARGS - ) + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + data_args = copy.deepcopy(DATA_ARGS) + data_args.response_template = None + with pytest.raises(ValueError): - sft_trainer.train(model_args, data_args, training_args, tune_config) + sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS) ### Tests for model dtype edge cases @@ -497,26 +417,22 @@ def test_bf16_still_tunes_if_unsupported(): """Ensure that even if bf16 is not supported, tuning still works without problems.""" assert not torch.cuda.is_bf16_supported() with tempfile.TemporaryDirectory() as tempdir: - TRAIN_KWARGS = { - **BASE_PEFT_KWARGS, - **{"torch_dtype": "bfloat16", "output_dir": tempdir}, - } - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - TRAIN_KWARGS - ) - sft_trainer.train(model_args, data_args, training_args, tune_config) + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + model_args = copy.deepcopy(MODEL_ARGS) + model_args.torch_dtype = "bfloat16" + + sft_trainer.train(model_args, DATA_ARGS, train_args, PEFT_PT_ARGS) _validate_training(tempdir) def test_bad_torch_dtype(): """Ensure that specifying an invalid torch dtype yields a ValueError.""" with tempfile.TemporaryDirectory() as tempdir: - TRAIN_KWARGS = { - **BASE_PEFT_KWARGS, - **{"torch_dtype": "not a type", "output_dir": tempdir}, - } - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - TRAIN_KWARGS - ) + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + model_args = copy.deepcopy(MODEL_ARGS) + model_args.torch_dtype = "not a type" + with pytest.raises(ValueError): - sft_trainer.train(model_args, data_args, training_args, tune_config) + sft_trainer.train(model_args, DATA_ARGS, train_args, PEFT_PT_ARGS) From 697b573d6a549a502d0311c51db92bb47ea771d8 Mon Sep 17 00:00:00 2001 From: Anh Uong Date: Thu, 23 May 2024 13:39:41 -0600 Subject: [PATCH 2/2] docs: update eval_strategy flag used in transformers (#168) Signed-off-by: Anh-Uong --- README.md | 8 ++++---- examples/prompt_tuning_twitter_complaints/README.md | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 887d61502..bb1a95876 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ python tuning/sft_trainer.py \ --per_device_train_batch_size 4 \ --per_device_eval_batch_size 4 \ --gradient_accumulation_steps 4 \ ---evaluation_strategy "no" \ +--eval_strategy "no" \ --save_strategy "epoch" \ --learning_rate 1e-5 \ --weight_decay 0. \ @@ -125,7 +125,7 @@ tuning/sft_trainer.py \ --per_device_train_batch_size 4 \ --per_device_eval_batch_size 4 \ --gradient_accumulation_steps 4 \ ---evaluation_strategy "no" \ +--eval_strategy "no" \ --save_strategy "epoch" \ --learning_rate 1e-5 \ --weight_decay 0. \ @@ -279,7 +279,7 @@ tuning/sft_trainer.py \ --per_device_train_batch_size 1 \ --per_device_eval_batch_size 1 \ --gradient_accumulation_steps 1 \ ---evaluation_strategy "no" \ +--eval_strategy "no" \ --save_strategy "epoch" \ --learning_rate 1e-5 \ --weight_decay 0. \ @@ -313,7 +313,7 @@ tuning/sft_trainer.py \ --per_device_train_batch_size 1 \ --per_device_eval_batch_size 1 \ --gradient_accumulation_steps 1 \ ---evaluation_strategy "no" \ +--eval_strategy "no" \ --save_strategy "epoch" \ --learning_rate 1e-5 \ --weight_decay 0. \ diff --git a/examples/prompt_tuning_twitter_complaints/README.md b/examples/prompt_tuning_twitter_complaints/README.md index c8383cd57..cd8e95233 100644 --- a/examples/prompt_tuning_twitter_complaints/README.md +++ b/examples/prompt_tuning_twitter_complaints/README.md @@ -51,7 +51,7 @@ tuning/sft_trainer.py \ --per_device_train_batch_size 1 \ --per_device_eval_batch_size 1 \ --gradient_accumulation_steps 1 \ ---evaluation_strategy "no" \ +--eval_strategy "no" \ --save_strategy "epoch" \ --learning_rate 1e-5 \ --weight_decay 0. \