diff --git a/config_files/data_preparation/packed_cc_en_2048.yaml b/config_files/data_preparation/packed_cc_en_2048.yaml index 3f0f6023..36bfa3a0 100644 --- a/config_files/data_preparation/packed_cc_en_2048.yaml +++ b/config_files/data_preparation/packed_cc_en_2048.yaml @@ -15,4 +15,4 @@ tokenizer: config: tokenizer_model_file: /workspaces/modalities/data/tokenizer/sp_bpe_en/bpe_tokenizer.model padding: false - max_length: 2048 \ No newline at end of file + truncation: false diff --git a/config_files/data_preparation/packed_dataset_config.yaml b/config_files/data_preparation/packed_dataset_config.yaml index 7b999a88..9644e251 100644 --- a/config_files/data_preparation/packed_dataset_config.yaml +++ b/config_files/data_preparation/packed_dataset_config.yaml @@ -12,4 +12,4 @@ tokenizer: config: pretrained_model_name_or_path: data/tokenizer/hf_gpt2 padding: false - max_length: 512 + truncation: false diff --git a/config_files/text_generation/text_generation_config_torch.yaml b/config_files/text_generation/text_generation_config_torch.yaml index a639586f..3e84ce8d 100644 --- a/config_files/text_generation/text_generation_config_torch.yaml +++ b/config_files/text_generation/text_generation_config_torch.yaml @@ -90,4 +90,4 @@ tokenizer: config: pretrained_model_name_or_path: /raid/s3/opengptx/max_lue/modalities/data/tokenizer/hf_gpt2 padding: false - max_length: ${settings.context_length} \ No newline at end of file + truncation: false diff --git a/config_files/text_generation/text_generation_overfitted_de.yaml b/config_files/text_generation/text_generation_overfitted_de.yaml index 600d166e..34c398ef 100644 --- a/config_files/text_generation/text_generation_overfitted_de.yaml +++ b/config_files/text_generation/text_generation_overfitted_de.yaml @@ -93,4 +93,4 @@ tokenizer: config: pretrained_model_name_or_path: /raid/s3/opengptx/max_lue/modalities/data/tokenizer/hf_gpt2 padding: false - max_length: ${settings.context_length} \ No newline at end of file + truncation: false diff --git a/examples/getting_started/README.md b/examples/getting_started/README.md index 04bfee82..776a473d 100644 --- a/examples/getting_started/README.md +++ b/examples/getting_started/README.md @@ -58,7 +58,7 @@ tokenizer: config: pretrained_model_name_or_path: tokenizer padding: false - max_length: 512 + truncation: false ``` ### Step 1: Create Index diff --git a/examples/getting_started/example_dataset_config_test.yaml b/examples/getting_started/example_dataset_config_test.yaml index 1d6c389a..0dcf04b8 100644 --- a/examples/getting_started/example_dataset_config_test.yaml +++ b/examples/getting_started/example_dataset_config_test.yaml @@ -15,4 +15,4 @@ tokenizer: config: pretrained_model_name_or_path: tokenizer padding: false - max_length: 512 + truncation: false diff --git a/examples/getting_started/example_dataset_config_train.yaml b/examples/getting_started/example_dataset_config_train.yaml index 7811c8ce..93f66e19 100644 --- a/examples/getting_started/example_dataset_config_train.yaml +++ b/examples/getting_started/example_dataset_config_train.yaml @@ -15,4 +15,4 @@ tokenizer: config: pretrained_model_name_or_path: tokenizer padding: false - max_length: 512 + truncation: false diff --git a/examples/getting_started/example_text_generation_config.yaml b/examples/getting_started/example_text_generation_config.yaml index f992e26f..08001713 100644 --- a/examples/getting_started/example_text_generation_config.yaml +++ b/examples/getting_started/example_text_generation_config.yaml @@ -93,4 +93,4 @@ tokenizer: config: 
pretrained_model_name_or_path: tokenizer padding: false - max_length: ${settings.context_length} \ No newline at end of file + truncation: false diff --git a/notebooks/components.yaml b/notebooks/components.yaml index 48b790f4..0c0e6aeb 100644 --- a/notebooks/components.yaml +++ b/notebooks/components.yaml @@ -11,7 +11,7 @@ tokenizer: config: tokenizer_model_file: /workspaces/modalities/notebooks/tokenizer/unigram_tokenizer.model padding: false - max_length: 2048 + truncation: false train_dataset: component_key: dataset diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index e24ecc05..15bc61ea 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -4,7 +4,14 @@ import torch from omegaconf import OmegaConf -from pydantic import BaseModel, Field, FilePath, PositiveInt, field_validator, model_validator +from pydantic import ( + BaseModel, + Field, + FilePath, + PositiveInt, + field_validator, + model_validator, +) from torch.distributed.fsdp import ShardingStrategy from transformers import GPT2TokenizerFast from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast @@ -146,7 +153,9 @@ class StepLRSchedulerConfig(BaseModel): class OneCycleLRSchedulerConfig(BaseModel): optimizer: PydanticOptimizerIFType - max_lr: Annotated[float, Field(strict=True, gt=0.0)] | List[Annotated[float, Field(strict=True, gt=0.0)]] + max_lr: Annotated[float, Field(strict=True, gt=0.0)] | List[ + Annotated[float, Field(strict=True, gt=0.0)] + ] total_steps: Optional[Annotated[int, Field(strict=True, gt=0)]] = None epochs: Optional[Annotated[int, Field(strict=True, gt=0)]] = None steps_per_epoch: Optional[Annotated[int, Field(strict=True, gt=0)]] = None @@ -167,8 +176,12 @@ class OneCycleLRSchedulerConfig(BaseModel): @model_validator(mode="after") def check_totals_steps_and_epchs(self) -> "OneCycleLRSchedulerConfig": - if self.total_steps is None and (self.epochs is None or self.steps_per_epoch is None): - raise ValueError("Please define total_steps or (epochs and steps_per_epoch).") + if self.total_steps is None and ( + self.epochs is None or self.steps_per_epoch is None + ): + raise ValueError( + "Please define total_steps or (epochs and steps_per_epoch)." 
+ ) return self @@ -227,9 +240,10 @@ def parse_sharding_strategy_by_name(cls, name): class PreTrainedHFTokenizerConfig(BaseModel): pretrained_model_name_or_path: str - max_length: Annotated[int, Field(strict=True, ge=0)] + max_length: Optional[Annotated[int, Field(strict=True, ge=0)]] = None truncation: bool = False padding: bool | str = False + special_tokens: Optional[Dict[str, str]] = None class PreTrainedSPTokenizerConfig(BaseModel): @@ -316,7 +330,9 @@ class DummyProgressSubscriberConfig(BaseModel): class RichProgressSubscriberConfig(BaseModel): train_dataloader: PydanticLLMDataLoaderIFType - eval_dataloaders: Optional[List[PydanticLLMDataLoaderIFType]] = Field(default_factory=list) + eval_dataloaders: Optional[List[PydanticLLMDataLoaderIFType]] = Field( + default_factory=list + ) global_num_seen_steps: int local_rank: int @@ -342,7 +358,11 @@ class RichResultSubscriberConfig(BaseModel): def load_app_config_dict(config_file_path: Path) -> Dict: def cuda_env_resolver_fun(var_name: str) -> int: int_env_variable_names = ["LOCAL_RANK", "WORLD_SIZE", "RANK"] - return int(os.getenv(var_name)) if var_name in int_env_variable_names else os.getenv(var_name) + return ( + int(os.getenv(var_name)) + if var_name in int_env_variable_names + else os.getenv(var_name) + ) def modalities_env_resolver_fun(var_name: str) -> int: if var_name == "experiment_id": @@ -355,7 +375,9 @@ def node_env_resolver_fun(var_name: str) -> int: return os.cpu_count() OmegaConf.register_new_resolver("cuda_env", cuda_env_resolver_fun, replace=True) - OmegaConf.register_new_resolver("modalities_env", modalities_env_resolver_fun, replace=True) + OmegaConf.register_new_resolver( + "modalities_env", modalities_env_resolver_fun, replace=True + ) OmegaConf.register_new_resolver("node_env", node_env_resolver_fun, replace=True) cfg = OmegaConf.load(config_file_path) diff --git a/src/modalities/models/gpt2/preprocess_dataset.py b/src/modalities/models/gpt2/preprocess_dataset.py deleted file mode 100644 index e89d591e..00000000 --- a/src/modalities/models/gpt2/preprocess_dataset.py +++ /dev/null @@ -1,59 +0,0 @@ -import os -from itertools import chain - -from accelerate import Accelerator -from datasets import load_dataset -from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast - - -def main(): - def group_texts(examples): - # Concatenate all texts. - concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} - total_length = len(concatenated_examples[list(examples.keys())[0]]) - # We drop the small remainder, and if the total_length < block_size - # we exclude this batch and return an empty dict. We could add padding if the - # model supported it instead of this drop, you can customize this part to your needs. - total_length = (total_length // block_size) * block_size - # Split by chunks of max_len. 
- result = { - k: [t[i : i + block_size] for i in range(0, total_length, block_size)] - for k, t in concatenated_examples.items() - } - result["labels"] = result["input_ids"].copy() - return result - - def tokenize_function(examples): - return tokenizer(examples[text_column_name]) - - dataset_name = "wikitext-103-raw-v1" # "wikitext-2-raw-v1" - - accelerator = Accelerator(gradient_accumulation_steps=1) - tokenizer_file_path = os.path.join(os.path.dirname(__file__), "tokenizer.json") - tokenizer = GPT2TokenizerFast(tokenizer_file=tokenizer_file_path) - raw_datasets = load_dataset(path="wikitext", name=dataset_name) - column_names = raw_datasets["train"].column_names - text_column_name = "text" if "text" in column_names else column_names[0] - block_size = tokenizer.model_max_length - if block_size > 1024: - block_size = 1024 - # datasets.save_to_disk('wikitext-2-raw-v1') - gpt_version: str = "gpt2" - config = GPT2Config.from_pretrained(gpt_version, output_hidden_stages=False) - model = GPT2LMHeadModel.from_pretrained(gpt_version, config=config) - embedding_size = model.get_input_embeddings().weight.shape[0] - if len(tokenizer) > embedding_size: - model.resize_token_embeddings(len(tokenizer)) - - with accelerator.main_process_first(): - tokenized_datasets = raw_datasets.map(tokenize_function, batched=True, num_proc=2, remove_columns=column_names) - - with accelerator.main_process_first(): - llm_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=2) - print(llm_datasets) - dataset_path = os.path.join(os.path.dirname(__file__), f"data/{dataset_name}-tokenized") - llm_datasets.save_to_disk(dataset_path) - - -if __name__ == "__main__": - main() diff --git a/src/modalities/tokenization/tokenizer_wrapper.py b/src/modalities/tokenization/tokenizer_wrapper.py index 158ebd04..22267840 100644 --- a/src/modalities/tokenization/tokenizer_wrapper.py +++ b/src/modalities/tokenization/tokenizer_wrapper.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import List +from typing import Dict, List, Optional import sentencepiece as spm from transformers import AutoTokenizer @@ -25,9 +25,26 @@ def get_token_id(self, token: str) -> int: class PreTrainedHFTokenizer(TokenizerWrapper): def __init__( - self, pretrained_model_name_or_path: str, max_length: int, truncation: bool = True, padding: str = "max_length" + self, + pretrained_model_name_or_path: str, + truncation: bool = False, + padding: bool | str = False, + max_length: Optional[int] = None, + special_tokens: Optional[Dict[str, str]] = None, ) -> None: - self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path) + # also see here for the truncation and padding options and their effects: + # https://huggingface.co/docs/transformers/pad_truncation#padding-and-truncation + + self.tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path + ) + if special_tokens is not None: + # TODO check if we always want to set + # replace_additional_special_tokens=False + self.tokenizer.add_special_tokens( + special_tokens_dict=special_tokens, + replace_additional_special_tokens=False, + ) self.max_length = max_length self.truncation = truncation self.padding = padding @@ -36,6 +53,10 @@ def __init__( def vocab_size(self): return self.tokenizer.vocab_size + @property + def special_tokens(self) -> Dict[str, str | List[str]]: + return self.tokenizer.special_tokens_map + def tokenize(self, text: str) -> List[int]: tokens = self.tokenizer.__call__( text, diff --git 
a/tests/test_tokenization.py b/tests/test_tokenization.py index f5facdb7..061127ee 100644 --- a/tests/test_tokenization.py +++ b/tests/test_tokenization.py @@ -1,6 +1,12 @@ +import numpy as np import pytest -from modalities.tokenization.tokenizer_wrapper import PreTrainedHFTokenizer, PreTrainedSPTokenizer, TokenizerWrapper +from modalities.config.config import PreTrainedHFTokenizerConfig +from modalities.tokenization.tokenizer_wrapper import ( + PreTrainedHFTokenizer, + PreTrainedSPTokenizer, + TokenizerWrapper, +) def _assert_tokenization(tokenizer: TokenizerWrapper): @@ -9,12 +15,303 @@ def _assert_tokenization(tokenizer: TokenizerWrapper): assert len(token_ids) > 0 -def test_hf_tokenize(): - tokenizer_model_file = "data/tokenizer/hf_gpt2" - tokenizer = PreTrainedHFTokenizer( - pretrained_model_name_or_path=tokenizer_model_file, max_length=20, truncation=False, padding=False +@pytest.mark.parametrize( + "text,tokenizer_config,expected_length,expected_num_padding_tokens", + [ + # Test cases 1: Sequence is shorter than max_length, i.e., len(text) < max_length + # If padding="max_length", we want a sequence to be padded to the max_length, irrespective of the truncation flag, + # provided that max_length is specified. + # If max_length is not specified, we pad to the max model input length (i.e., 1024 for the gpt2 model). + # NOTE: "AAAAAAAA" is a single token for the gpt2 tokenizer, there is no "A" sequence longer than that in the vocabulary. + ( + "AAAAAAAA" * 6, + PreTrainedHFTokenizerConfig( + pretrained_model_name_or_path="data/tokenizer/hf_gpt2", + truncation=False, + padding="max_length", + max_length=10, + special_tokens={"pad_token": "[PAD]"}, + ), + 10, + 4, + ), + ( + "AAAAAAAA" * 6, + PreTrainedHFTokenizerConfig( + pretrained_model_name_or_path="data/tokenizer/hf_gpt2", + truncation=True, + padding="max_length", + max_length=10, + special_tokens={"pad_token": "[PAD]"}, + ), + 10, + 4, + ), + ( + "AAAAAAAA" * 6, + PreTrainedHFTokenizerConfig( + pretrained_model_name_or_path="data/tokenizer/hf_gpt2", + truncation=False, + padding="max_length", + max_length=None, + special_tokens={"pad_token": "[PAD]"}, + ), + 1024, + 1018, + ), + ( + "AAAAAAAA" * 6, + PreTrainedHFTokenizerConfig( + pretrained_model_name_or_path="data/tokenizer/hf_gpt2", + truncation=True, + padding="max_length", + max_length=None, + special_tokens={"pad_token": "[PAD]"}, + ), + 1024, + 1018, + ), + # If padding=False, we want no padding to be applied, + # irrespective of the truncation flag and max_length. + ( + "AAAAAAAA" * 6, + PreTrainedHFTokenizerConfig( + pretrained_model_name_or_path="data/tokenizer/hf_gpt2", + truncation=False, + padding=False, + max_length=10, + special_tokens={"pad_token": "[PAD]"}, + ), + 6, + 0, + ), + ( + "AAAAAAAA" * 6, + PreTrainedHFTokenizerConfig( + pretrained_model_name_or_path="data/tokenizer/hf_gpt2", + truncation=True, + padding=False, + max_length=10, + special_tokens={"pad_token": "[PAD]"}, + ), + 6, + 0, + ), + # NOTE: This is the setting used for pretraining dataset tokenisation!!!
+ ( + "AAAAAAAA" * 6, + PreTrainedHFTokenizerConfig( + pretrained_model_name_or_path="data/tokenizer/hf_gpt2", + truncation=False, + padding=False, + max_length=None, + special_tokens={"pad_token": "[PAD]"}, + ), + 6, + 0, + ), + ( + "AAAAAAAA" * 6, + PreTrainedHFTokenizerConfig( + pretrained_model_name_or_path="data/tokenizer/hf_gpt2", + truncation=True, + padding=False, + max_length=None, + special_tokens={"pad_token": "[PAD]"}, + ), + 6, + 0, + ), + # Test cases 2: Sequence is longer than max_length, i.e., len(text) > max_length + # If truncation=True and len(text) max model input length + # NOTE: This is a typical case when tokenising the pretraining dataset!!! + ( + "AAAAAAAA" * 1030, + PreTrainedHFTokenizerConfig( + pretrained_model_name_or_path="data/tokenizer/hf_gpt2", + truncation=False, + padding=False, + max_length=None, + special_tokens={"pad_token": "[PAD]"}, + ), + 1030, + 0, + ), + ( + "AAAAAAAA" * 1030, + PreTrainedHFTokenizerConfig( + pretrained_model_name_or_path="data/tokenizer/hf_gpt2", + truncation=False, + padding=True, + max_length=None, + special_tokens={"pad_token": "[PAD]"}, + ), + 1030, + 0, + ), + ( + "AAAAAAAA" * 1030, + PreTrainedHFTokenizerConfig( + pretrained_model_name_or_path="data/tokenizer/hf_gpt2", + truncation=True, + padding=True, + max_length=None, + special_tokens={"pad_token": "[PAD]"}, + ), + 1024, + 0, + ), + # if we want to pad to max model input length we have to use padding="max_length", + # otherwise we will only pad to the longest sequence in the batch. + # see: https://huggingface.co/docs/transformers/pad_truncation#padding-and-truncation + ( + "AAAAAAAA" * 1020, + PreTrainedHFTokenizerConfig( + pretrained_model_name_or_path="data/tokenizer/hf_gpt2", + truncation=True, + padding=True, + max_length=None, + special_tokens={"pad_token": "[PAD]"}, + ), + 1020, + 0, + ), + ( + "AAAAAAAA" * 1020, + PreTrainedHFTokenizerConfig( + pretrained_model_name_or_path="data/tokenizer/hf_gpt2", + truncation=True, + padding="max_length", + max_length=None, + special_tokens={"pad_token": "[PAD]"}, + ), + 1024, + 4, + ), + ], +) +def test_hf_tokenize( + text: str, + tokenizer_config: PreTrainedHFTokenizerConfig, + expected_length: int, + expected_num_padding_tokens: int, +): + # also see here for the truncation and padding options and their effects: + # https://huggingface.co/docs/transformers/pad_truncation#padding-and-truncation + + tokenizer_config_dict = tokenizer_config.model_dump() + tokenizer = PreTrainedHFTokenizer(**tokenizer_config_dict) + + token_ids = tokenizer.tokenize(text) + + # make sure that the overall token sequence length is correct + assert len(token_ids) == expected_length + + # check number of non-padding tokens (token_id = 43488 corresponds to "AAAAAAAA") + assert sum(np.array(token_ids) == 43488) == ( + expected_length - expected_num_padding_tokens ) - _assert_tokenization(tokenizer) + + # check number of padding tokens + assert sum(np.array(token_ids) == 50257) == expected_num_padding_tokens @pytest.mark.skip(reason="Missing pretrained unigram sp tokenizer.") diff --git a/tests/test_yaml_configs/config_lorem_ipsum.yaml b/tests/test_yaml_configs/config_lorem_ipsum.yaml index 84718754..8f832e99 100644 --- a/tests/test_yaml_configs/config_lorem_ipsum.yaml +++ b/tests/test_yaml_configs/config_lorem_ipsum.yaml @@ -20,13 +20,6 @@ settings: paths: checkpointing_path: data/checkpoints -tokenizer: - component_key: tokenizer - variant_key: pretrained_hf_tokenizer - config: - pretrained_model_name_or_path: ./data/tokenizer/hf_gpt2 - max_length: 
${settings.training.sequence_length} - collate_fn: component_key: collate_fn variant_key: gpt_2_llm_collator
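Usage sketch (illustrative, not part of the patch): the keyword arguments introduced on PreTrainedHFTokenizer in tokenizer_wrapper.py mirror the Hugging Face padding/truncation options linked in the code comments. The values below reflect the pretraining-style configuration used in the updated YAML configs (no padding, no truncation); the tokenizer path, sample text, and the optional pad token are placeholders taken from the tests above.

from modalities.tokenization.tokenizer_wrapper import PreTrainedHFTokenizer

# Pretraining-style tokenization: keep documents untruncated and unpadded;
# max_length only takes effect once truncation or padding="max_length" is enabled.
tokenizer = PreTrainedHFTokenizer(
    pretrained_model_name_or_path="data/tokenizer/hf_gpt2",  # placeholder path
    truncation=False,
    padding=False,
    max_length=None,
    special_tokens={"pad_token": "[PAD]"},  # optional, as exercised in the tests
)
token_ids = tokenizer.tokenize("some example document text")  # returns a List[int]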