Skip to content

Commit

Permalink
refactor utils.data module for line count linter (#1476)
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Apr 4, 2024
1 parent c2b64e4 commit e0fcef4
Show file tree
Hide file tree
Showing 5 changed files with 376 additions and 330 deletions.
15 changes: 15 additions & 0 deletions src/axolotl/utils/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
Data processing modules
"""
from axolotl.utils.data.dpo import load_prepare_dpo_datasets # noqa: F401
from axolotl.utils.data.pretraining import ( # noqa: F401
encode_pretraining,
wrap_pretraining_dataset,
)
from axolotl.utils.data.sft import ( # noqa: F401
get_dataset_wrapper,
load_prepare_datasets,
load_tokenized_prepared_datasets,
prepare_dataset,
)
from axolotl.utils.data.utils import md5 # noqa: F401
114 changes: 114 additions & 0 deletions src/axolotl/utils/data/dpo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""data handling specific to DPO"""

import logging
from pathlib import Path
from typing import Any, List

import yaml
from datasets import concatenate_datasets, load_dataset, load_from_disk

from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.utils.data.utils import md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first

LOG = logging.getLogger("axolotl")


def _get_path(ds_hash, cfg):
prepared_ds_path = (
Path(cfg.dataset_prepared_path) / ds_hash
if cfg.dataset_prepared_path
else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
)

return prepared_ds_path


def _load_preprocessed_ds(cfg, sub_cfg):
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
prepared_ds_path = _get_path(ds_hash, cfg)
dataset = None

# pylint: disable=duplicate-code
if (
cfg.dataset_prepared_path
and any(prepared_ds_path.glob("*"))
and not cfg.is_preprocess
):
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset = load_from_disk(str(prepared_ds_path))

return dataset


def _save_preprocessed_ds(cfg, sub_cfg, dataset):
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
prepared_ds_path = _get_path(ds_hash, cfg)

if cfg.is_preprocess and is_main_process():
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset.save_to_disk(str(prepared_ds_path))


def load_prepare_dpo_datasets(cfg):
def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = []
for i, ds_cfg in enumerate(dataset_cfgs):
if ds_cfg["ds_type"] == "json":
for data_file in ds_cfg["data_files"]:
data_files = {ds_cfg["split"]: data_file}
ds = load_dataset( # pylint: disable=invalid-name
"json",
data_files=data_files,
split=ds_cfg["split"],
)
split_datasets.insert(i, ds)
else:
ds = load_dataset( # pylint: disable=invalid-name
ds_cfg["path"],
split=ds_cfg["split"],
)
split_datasets.insert(i, ds)

for i, data_set in enumerate(split_datasets):
_type = dataset_cfgs[i]["type"]
if _type:
if isinstance(_type, DictDefault):
_type = "user_defined.default"
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
split_datasets[i] = data_set.map(
ds_transform_fn,
desc="Mapping RL Dataset",
)
else:
# If no `type` is provided, assume the dataset is already in the expected format with
# "prompt", "chosen" and "rejected" already preprocessed
split_datasets[i] = data_set

return concatenate_datasets(split_datasets)

with zero_first(is_main_process()):
train_is_preprocessed = False
eval_is_preprocessed = False
if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
train_is_preprocessed = True
else:
train_dataset = load_split(cfg.datasets, cfg)

eval_dataset = None
if cfg.test_datasets:
if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
eval_is_preprocessed = True
else:
eval_dataset = load_split(cfg.test_datasets, cfg)
if not eval_dataset:
eval_dataset = None

if not train_is_preprocessed:
_save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
if eval_dataset and not eval_is_preprocessed:
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)

return train_dataset, eval_dataset
232 changes: 232 additions & 0 deletions src/axolotl/utils/data/pretraining.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
"""data handling specific to pretraining"""

import functools
import logging
from collections import defaultdict
from typing import Callable, Dict, List, Optional

import torch
from datasets import Dataset
from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase

from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import process_pretraining_datasets_for_packing

LOG = logging.getLogger("axolotl")


def encode_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
) -> Dict[str, List]:
res = tokenizer(
examples,
truncation=True,
max_length=max_tokens - 2,
add_special_tokens=True,
)
# Convert to PyTorch tensors
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
new_input_ids = []
new_attention_mask = []
# Append EOS and PAD tokens to input_ids, and correct attention_mask
for i, _ in enumerate(input_ids):
input_ids[i] = torch.cat(
(
input_ids[i],
torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
),
dim=0,
)
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)

# Concatenate tokens so that their lengths are less than max_tokens
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)

for ids, mask in zip(input_ids, attention_mask):
if buffer_input_ids.numel() == max_tokens:
new_input_ids.append(buffer_input_ids)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
else:
buffer_input_ids = torch.cat(
(
buffer_input_ids,
torch.full(
(max_tokens - buffer_input_ids.numel(),),
tokenizer.pad_token_id,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
torch.full(
(max_tokens - buffer_attention_mask.numel(),),
0,
dtype=torch.long,
),
),
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)

buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)

if buffer_input_ids.numel() > 0: # for any leftover tokens
while buffer_input_ids.numel() < max_tokens: # make all sequences equal in size
buffer_input_ids = torch.cat(
(
buffer_input_ids,
torch.full(
(max_tokens - buffer_input_ids.numel(),),
tokenizer.pad_token_id,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
torch.full(
(max_tokens - buffer_attention_mask.numel(),),
0,
dtype=torch.long,
),
),
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_attention_mask.append(buffer_attention_mask)

ret = {
"input_ids": [seq.tolist() for seq in new_input_ids],
"labels": [seq.tolist() for seq in new_input_ids],
"attention_mask": [seq.tolist() for seq in new_attention_mask],
}

LOG.debug(len(ret["input_ids"]))
return ret


def wrap_pretraining_dataset(
dataset,
tokenizer,
cfg,
ds_wrapper_fn,
max_tokens=2048,
batch_size=1,
seed=42,
buffer_size=10_000,
):
if cfg.sample_packing:
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding=True,
pad_to_multiple_of=max_tokens * batch_size,
multipack_attn=cfg.pretrain_multipack_attn,
)
encode = functools.partial(
encode_packed_pretraining,
collate_fn,
ds_wrapper_fn,
max_seq_length=max_tokens,
batch_size=batch_size,
multipack_attn=cfg.pretrain_multipack_attn,
)
# set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1
else:
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)

if cfg.shuffle_merged_datasets:
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
else:
LOG.debug("NOT shuffling merged pretraining datasets")

# remove all the existing columns after mapping since they end up having
# a different length than the encoded/tokenized column
# this is empty during streaming/pretraining
remove_columns = []
if dataset.features is None:
for first_row in dataset:
remove_columns = first_row.keys()
break
else:
remove_columns = dataset.features.keys()

dataset = dataset.map(
encode,
batched=True,
batch_size=buffer_size,
# input_columns="text",
remove_columns=remove_columns,
)
return dataset


def encode_packed_pretraining(
collate_fn,
ds_wrapper: Callable,
examples: Dict[str, List],
max_seq_length: int = 2048,
batch_size: int = 4,
multipack_attn: Optional[bool] = False,
) -> Dict[str, List]:
# pylint: disable=duplicate-code
# tokenize all the examples
# rows get split with stride (overlap)
train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]

train_dataset = process_pretraining_datasets_for_packing(
train_dataset,
max_seq_length,
skip_position_ids=not multipack_attn,
)

sampler = MultipackBatchSampler(
RandomSampler(train_dataset),
batch_size=1,
drop_last=True,
batch_max_len=batch_size * max_seq_length,
lengths=get_dataset_lengths(train_dataset),
)

chunked_data = defaultdict(list)

for batch in sampler:
for data in batch:
features = train_dataset[data]
if "num_truncated_tokens" in features:
del features["num_truncated_tokens"]
if "num_truncated_tokens" in features:
del features["num_truncated_tokens"]
if "overflow_to_sample_mapping" in features:
del features["overflow_to_sample_mapping"]
if "labels" not in features:
features["labels"] = features["input_ids"].copy()
collated_features = collate_fn(features)

for feature in features.keys():
if feature == "length":
continue
chunked_data[feature].append(collated_features[feature].squeeze(0))

return chunked_data
Loading

0 comments on commit e0fcef4

Please sign in to comment.