Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
dchourasia committed Dec 11, 2024
2 parents 52ce7ec + 4168c87 commit 1d1bba9
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 50 deletions.
9 changes: 6 additions & 3 deletions tests/artifacts/predefined_data_configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@

### Constants used for data
PREDEFINED_DATA_CONFIGS = os.path.join(os.path.dirname(__file__))
APPLY_CUSTOM_TEMPLATE_YAML = os.path.join(
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML = os.path.join(
PREDEFINED_DATA_CONFIGS, "apply_custom_template.yaml"
)
PRETOKENIZE_JSON_DATA_YAML = os.path.join(
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML = os.path.join(
PREDEFINED_DATA_CONFIGS, "pretokenized_json_data.yaml"
)
TOKENIZE_AND_APPLY_INPUT_MASKING_YAML = os.path.join(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML = os.path.join(
PREDEFINED_DATA_CONFIGS, "tokenize_and_apply_input_masking.yaml"
)
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML = os.path.join(
PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_sampling.yaml"
)
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ datasets:
batched: false
fn_kwargs:
dataset_text_field: "dataset_text_field"
dataset_template: "dataset_template"
template: "dataset_template"
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
dataprocessor:
type: default
sampling_stopping_strategy: first_exhausted
seed: 66
datasets:
- name: dataset_1
sampling: 0.3
data_paths:
- "FILE_PATH"
data_handlers:
- name: tokenize_and_apply_input_masking
arguments:
remove_columns: all
batched: false
fn_kwargs:
input_field_name: input
output_field_name: output
- name: dataset_2
sampling: 0.4
data_paths:
- "FILE_PATH"
data_handlers:
- name: tokenize_and_apply_input_masking
arguments:
remove_columns: all
batched: false
fn_kwargs:
input_field_name: input
output_field_name: output
- name: dataset_3
sampling: 0.3
data_paths:
- "FILE_PATH"
data_handlers:
- name: tokenize_and_apply_input_masking
arguments:
remove_columns: all
batched: false
fn_kwargs:
input_field_name: input
output_field_name: output
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ datasets:
remove_columns: all
batched: false
fn_kwargs:
input_field: "INPUT"
output_field: "OUTPUT"
input_field_name: "INPUT"
output_field_name: "OUTPUT"
129 changes: 116 additions & 13 deletions tests/data/test_data_preprocessing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@

# First Party
from tests.artifacts.predefined_data_configs import (
APPLY_CUSTOM_TEMPLATE_YAML,
PRETOKENIZE_JSON_DATA_YAML,
TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
)
from tests.artifacts.testdata import (
MODEL_NAME,
Expand Down Expand Up @@ -428,22 +429,22 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
@pytest.mark.parametrize(
"data_config_path, data_path",
[
(APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
(APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
(APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
(PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
(PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
(PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET),
(
TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
),
(
TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
),
(
TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
),
],
Expand Down Expand Up @@ -709,3 +710,105 @@ def test_process_dataset_configs(datafile, column_names, datasetconfigname):
with open(datafile, "r") as file:
data = json.load(file)
assert len(train_dataset) == len(data)


@pytest.mark.parametrize(
"datafiles, sampling, datasetconfigname",
[
(
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
],
[0.3, None, 0.3],
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
),
(
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
],
[0.3, 0.5, 0.3],
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
),
],
)
def test_process_dataset_configs_with_sampling_error(
datafiles, sampling, datasetconfigname
):

data_args = configs.DataArguments()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
TRAIN_ARGS = configs.TrainingArguments(
packing=False,
max_seq_length=1024,
output_dir="tmp", # Not needed but positional
)

with tempfile.NamedTemporaryFile(
"w", delete=False, suffix=".yaml"
) as temp_yaml_file:
with open(datasetconfigname, "r") as f:
data = yaml.safe_load(f)
datasets = data["datasets"]
for i in range(len(datasets)):
d = datasets[i]
d["data_paths"][0] = datafiles[i]
d["sampling"] = sampling[i]
yaml.dump(data, temp_yaml_file)
data_args.data_config_path = temp_yaml_file.name

with pytest.raises(ValueError):
(_, _, _, _, _, _) = process_dataargs(
data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
)


@pytest.mark.parametrize(
"datafiles, datasetconfigname",
[
(
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
],
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
),
],
)
def test_process_dataset_configs_with_sampling(datafiles, datasetconfigname):

data_args = configs.DataArguments()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
TRAIN_ARGS = configs.TrainingArguments(
packing=False,
max_seq_length=1024,
output_dir="tmp", # Not needed but positional
)

with tempfile.NamedTemporaryFile(
"w", delete=False, suffix=".yaml"
) as temp_yaml_file:
with open(datasetconfigname, "r") as f:
data = yaml.safe_load(f)
datasets = data["datasets"]
for i in range(len(datasets)):
d = datasets[i]
d["data_paths"][0] = datafiles[i]
yaml.dump(data, temp_yaml_file)
data_args.data_config_path = temp_yaml_file.name

(train_set, eval_set, _, _, _, _) = process_dataargs(
data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
)

assert isinstance(train_set, Dataset)
if eval_set:
assert isinstance(eval_set, Dataset)

assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
if eval_set:
assert set(["input_ids", "labels"]).issubset(set(eval_set.column_names))
39 changes: 27 additions & 12 deletions tuning/data/data_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,16 @@ class DataHandlerConfig:
class DataSetConfig:
name: str
data_paths: List[str]
sampling: Optional[Dict] = None
sampling: Optional[float] = None
data_handlers: Optional[List[DataHandlerConfig]] = None


@dataclass
class DataPreProcessorConfig:
type: Optional[str] = "default"
sampling_stopping_strategy: Optional[str] = "all_exhausted"
# Default seed is not none to ensure reproducability
sampling_seed: Optional[float] = 42


@dataclass
Expand Down Expand Up @@ -84,17 +87,12 @@ def _validate_dataset_config(dataset_config) -> DataSetConfig:
)
p = _p
c.data_paths.append(p)
if "sampling" in kwargs:
sampling_kwargs = kwargs["sampling"]
assert isinstance(
dict, sampling_kwargs
), "sampling arguments should be of the type dict"
if "ratio" in sampling_kwargs:
ratio = sampling_kwargs["ratio"]
assert isinstance(ratio, float) and (
0 <= ratio <= 1.0
), f"sampling ratio: {ratio} should be float and in range [0.0,1.0]"
c.sampling = sampling_kwargs
if "sampling" in kwargs and kwargs["sampling"] is not None:
ratio = kwargs["sampling"]
assert isinstance(ratio, float) and (
0 <= ratio <= 1.0
), f"sampling ratio: {ratio} should be float and in range [0.0,1.0]"
c.sampling = ratio
if "data_handlers" in kwargs:
c.data_handlers = []
for handler in kwargs["data_handlers"]:
Expand All @@ -106,6 +104,23 @@ def _validate_dataprocessor_config(dataprocessor_config) -> DataPreProcessorConf
kwargs = dataprocessor_config
c = DataPreProcessorConfig()
assert isinstance(kwargs, dict), "dataprocessor in data_config needs to be a dict"
if "type" in kwargs:
assert isinstance(kwargs["type"], str), "dataprocessor type must be a string"
c.type = kwargs["type"]
if "sampling_stopping_strategy" in kwargs:
strategy = kwargs["sampling_stopping_strategy"]
assert isinstance(
strategy, str
), "dataset sampling stopping strategy must be a string"
assert strategy in [
"first_exhausted",
"all_exhausted",
], "allowed sampling stopping strategies are all_exhausted(default) or first_exhausted"
c.sampling_stopping_strategy = strategy
if "sampling_seed" in kwargs:
seed = kwargs["sampling_seed"]
assert isinstance(seed, int), "sampling seed should be int"
c.sampling_seed = seed
return c


Expand Down
Loading

0 comments on commit 1d1bba9

Please sign in to comment.