Skip to content

Commit

Permalink
update to be deprecated evaluation_strategy and c4 dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Jul 5, 2024
1 parent b3f680d commit 61b8ecf
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 425 deletions.
16 changes: 7 additions & 9 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1187,17 +1187,15 @@ def build(self, total_num_steps):

if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no"
training_arguments_kwargs["eval_strategy"] = "no"
elif self.cfg.eval_steps:
training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
elif self.cfg.evaluation_strategy:
training_arguments_kwargs[
"evaluation_strategy"
] = self.cfg.evaluation_strategy
elif self.cfg.eval_strategy:
training_arguments_kwargs["eval_strategy"] = self.cfg.eval_strategy
else:
# we have an eval set, but no steps defined, default to use epoch
training_arguments_kwargs["evaluation_strategy"] = "epoch"
training_arguments_kwargs["eval_strategy"] = "epoch"

if self.cfg.save_steps:
training_arguments_kwargs["save_strategy"] = "steps"
Expand Down Expand Up @@ -1559,10 +1557,10 @@ def build_training_arguments(self, total_num_steps):
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors

if self.eval_dataset:
training_args_kwargs["evaluation_strategy"] = "steps"
training_args_kwargs["eval_strategy"] = "steps"
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
else:
training_args_kwargs["evaluation_strategy"] = "no"
training_args_kwargs["eval_strategy"] = "no"

if self.cfg.bf16 or self.cfg.bfloat16:
training_args_kwargs["bf16"] = True
Expand Down
5 changes: 1 addition & 4 deletions src/axolotl/utils/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,7 @@ def on_step_end(
control: TrainerControl,
**kwargs,
):
if (
args.evaluation_strategy == IntervalStrategy.STEPS
and state.global_step == 1
):
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
control.should_evaluate = True
return control

Expand Down
Loading

0 comments on commit 61b8ecf

Please sign in to comment.