diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 767fea1b4..7b52e3af4 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -637,6 +637,9 @@ def main(): combined_tracker_configs.file_logger_config = file_logger_config combined_tracker_configs.aim_config = aim_config + if training_args.output_dir: + os.makedirs(training_args.output_dir, exist_ok=True) + logger.info("using the output directory at %s", training_args.output_dir) try: trainer, additional_train_info = train( model_args=model_args,