From e757badf68c4aa5033c89ecae59cdddbb80f0852 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 17 Jul 2024 08:18:56 -0400 Subject: [PATCH] need to set both split_batches and dispatch_batches to false for pretraining --- .../utils/config/models/input/v0_4_1/__init__.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index aa698cf24..f05c091e9 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -77,6 +77,7 @@ class PretrainingDataset(BaseModel): split: Optional[str] = "train" text_column: Optional[str] = "text" type: Optional[str] = "pretrain" + trust_remote_code: Optional[bool] = False class UserDefinedPrompterType(BaseModel): @@ -118,6 +119,8 @@ class SFTDataset(BaseModel): roles: Optional[Dict[str, List[str]]] = None drop_system_message: Optional[bool] = None + trust_remote_code: Optional[bool] = False + class UserDefinedDPOType(BaseModel): """User defined typing for DPO""" @@ -158,6 +161,7 @@ class KTODataset(BaseModel): split: Optional[str] = None type: Optional[Union[UserDefinedKTOType, str]] = None data_files: Optional[List[str]] = None + trust_remote_code: Optional[bool] = False class RLType(str, Enum): @@ -711,9 +715,15 @@ def check_pretraining_split_batches_accelerate(cls, data): if data.get("pretraining_dataset"): accelerator_config = data.get("accelerator_config", {}) if not accelerator_config: - data["accelerator_config"] = {"split_batches": True} - elif accelerator_config.get("split_batches") is None: - data["accelerator_config"]["split_batches"] = True + data["accelerator_config"] = { + "split_batches": False, + "dispatch_batches": False, + } + else: + if accelerator_config.get("split_batches") is None: + data["accelerator_config"]["split_batches"] = False + if accelerator_config.get("dispatch_batches") is None: + data["accelerator_config"]["dispatch_batches"] = False return data @model_validator(mode="before")