Skip to content

Commit

Permalink
feat: support custom tokenizer (foundation-model-stack#229)
Browse files Browse the repository at this point in the history
* feat: support custom tokenizer

Signed-off-by: Mehant Kammakomati <[email protected]>

* fix: restore tokenizer check

Signed-off-by: Mehant Kammakomati <[email protected]>

---------

Signed-off-by: Mehant Kammakomati <[email protected]>
  • Loading branch information
kmehant authored Jul 13, 2024
1 parent e85eddb commit 77a195d
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 20 deletions.
39 changes: 27 additions & 12 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@
prompt_tuning_init="RANDOM",
num_virtual_tokens=8,
prompt_tuning_init_text="hello",
tokenizer_name_or_path=MODEL_NAME,
)

PEFT_LORA_ARGS = peft_config.LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05)
Expand Down Expand Up @@ -175,7 +174,12 @@ def test_run_causallm_pt_and_inference():
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)
_validate_adapter_config(adapter_config, "PROMPT_TUNING", PEFT_PT_ARGS)
# The prompt tuning config is built with tokenizer_name_or_path
# taken from the model arguments; that value defaults to
# model_name_or_path when it is not explicitly set.
_validate_adapter_config(
adapter_config, "PROMPT_TUNING", MODEL_ARGS.tokenizer_name_or_path
)

# Load the model
loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)
Expand Down Expand Up @@ -208,7 +212,12 @@ def test_run_causallm_pt_and_inference_with_formatting_data():
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)
_validate_adapter_config(adapter_config, "PROMPT_TUNING", PEFT_PT_ARGS)
# The prompt tuning config is built with tokenizer_name_or_path
# taken from the model arguments; that value defaults to
# model_name_or_path when it is not explicitly set.
_validate_adapter_config(
adapter_config, "PROMPT_TUNING", MODEL_ARGS.tokenizer_name_or_path
)

# Load the model
loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)
Expand Down Expand Up @@ -239,7 +248,12 @@ def test_run_causallm_pt_and_inference_JSON_file_formatter():
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)
_validate_adapter_config(adapter_config, "PROMPT_TUNING", PEFT_PT_ARGS)
# The prompt tuning config is built with tokenizer_name_or_path
# taken from the model arguments; that value defaults to
# model_name_or_path when it is not explicitly set.
_validate_adapter_config(
adapter_config, "PROMPT_TUNING", MODEL_ARGS.tokenizer_name_or_path
)

# Load the model
loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)
Expand All @@ -261,7 +275,6 @@ def test_run_causallm_pt_init_text():
tuning_config = peft_config.PromptTuningConfig(
prompt_tuning_init="TEXT",
prompt_tuning_init_text="hello",
tokenizer_name_or_path=MODEL_NAME,
)

sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, tuning_config)
Expand All @@ -270,7 +283,12 @@ def test_run_causallm_pt_init_text():
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)
_validate_adapter_config(adapter_config, "PROMPT_TUNING", tuning_config)
# The prompt tuning config is built with tokenizer_name_or_path
# taken from the model arguments; that value defaults to
# model_name_or_path when it is not explicitly set.
_validate_adapter_config(
adapter_config, "PROMPT_TUNING", MODEL_ARGS.tokenizer_name_or_path
)


invalid_params_map = [
Expand Down Expand Up @@ -364,7 +382,7 @@ def test_run_causallm_lora_and_inference(request, target_modules, expected):
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)
_validate_adapter_config(adapter_config, "LORA", base_lora_args)
_validate_adapter_config(adapter_config, "LORA")

for module in expected:
assert module in adapter_config.get("target_modules")
Expand Down Expand Up @@ -431,14 +449,11 @@ def _get_adapter_config(dir_path):
return json.load(f)


def _validate_adapter_config(adapter_config, peft_type, tuning_config):
def _validate_adapter_config(adapter_config, peft_type, tokenizer_name_or_path=None):
assert adapter_config.get("task_type") == "CAUSAL_LM"
assert adapter_config.get("peft_type") == peft_type
assert (
(
adapter_config.get("tokenizer_name_or_path")
== tuning_config.tokenizer_name_or_path
)
(adapter_config.get("tokenizer_name_or_path") == tokenizer_name_or_path)
if peft_type == "PROMPT_TUNING"
else True
)
Expand Down
11 changes: 11 additions & 0 deletions tuning/config/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,17 @@ class ModelArguments:
the given number after tokenizer modifications."
},
)
tokenizer_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "Path to custom tokenizer.\
If not provided it defaults to model_name_or_path"
},
)

def __post_init__(self):
if not self.tokenizer_name_or_path:
self.tokenizer_name_or_path = self.model_name_or_path


@dataclass
Expand Down
4 changes: 0 additions & 4 deletions tuning/config/peft_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,9 @@ class PromptTuningConfig:
prompt_tuning_init_text (`str`, *optional*):
The text to initialize the prompt embedding. \
Only used if `prompt_tuning_init` is `TEXT`.
tokenizer_name_or_path (`str`, *optional*):
The name or path of the tokenizer. \
Only used if `prompt_tuning_init` is `TEXT`.
num_virtual_tokens (`int`): The number of virtual tokens to use.
"""

prompt_tuning_init: str = "TEXT"
num_virtual_tokens: int = 8
prompt_tuning_init_text: str = "Classify if the tweet is a complaint or not:"
tokenizer_name_or_path: str = "llama-7b-hf"
6 changes: 4 additions & 2 deletions tuning/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,15 @@ def train(

# TODO: Move these to a config as well
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, cache_dir=train_args.cache_dir, use_fast=True
model_args.tokenizer_name_or_path, cache_dir=train_args.cache_dir, use_fast=True
)

# Calculate and save additional metrics to track later.
additional_metrics["model_load_time"] = time.time() - model_load_time

peft_config = get_hf_peft_config(task_type, peft_config)
peft_config = get_hf_peft_config(
task_type, peft_config, model_args.tokenizer_name_or_path
)

# TODO: understand if we need to hardcode these here or just use defaults in model
if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
Expand Down
7 changes: 5 additions & 2 deletions tuning/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,12 @@ def create_tuning_config(peft_method, **kwargs):
return tune_config


def get_hf_peft_config(task_type, tuning_config):
def get_hf_peft_config(task_type, tuning_config, tokenizer_name_or_path):
"""Return HF PEFT config for tuning based on type of tuning config passed
Args:
task_type: str
tuning_config: peft_config.LoraConfig | peft_config.PromptTuningConfig | None
tokenizer_name_or_path: str
Return: HF PEFT config or None
"""
if isinstance(tuning_config, peft_config.LoraConfig):
Expand All @@ -85,7 +86,9 @@ def get_hf_peft_config(task_type, tuning_config):
hf_peft_config = LoraConfig(task_type=task_type, **lora_config)
elif isinstance(tuning_config, peft_config.PromptTuningConfig):
hf_peft_config = PromptTuningConfig(
task_type=task_type, **asdict(tuning_config)
task_type=task_type,
tokenizer_name_or_path=tokenizer_name_or_path,
**asdict(tuning_config),
)
else:
hf_peft_config = None # full parameter tuning
Expand Down

0 comments on commit 77a195d

Please sign in to comment.