From bed95c9fda71fb3f60c5a2d8a131d971966db7b3 Mon Sep 17 00:00:00 2001
From: Keith Stevens
Date: Sat, 15 Jun 2024 07:21:41 +0000
Subject: [PATCH 1/3] Implementing a basic chat_template strategy for DPO
 datasets

This mimics the SFT chat_template strategy so that users can:
* Specify the messages field
* Specify the per-message role and content fields
* Specify the chosen and rejected fields
* Let the tokenizer construct the raw prompt
* Ensure the chosen and rejected fields don't have any prefix tokens
---
 examples/llama-3/instruct-dpo-lora-8b.yml     | 81 +++++++++++++++++
 .../prompt_strategies/dpo/chat_template.py    | 74 +++++++++++++++
 src/axolotl/utils/data/rl.py                  |  1 +
 src/axolotl/utils/tokenization.py             |  2 +-
 .../test_dpo_chat_templates.py                | 89 +++++++++++++++++++
 5 files changed, 246 insertions(+), 1 deletion(-)
 create mode 100644 examples/llama-3/instruct-dpo-lora-8b.yml
 create mode 100644 src/axolotl/prompt_strategies/dpo/chat_template.py
 create mode 100644 tests/prompt_strategies/test_dpo_chat_templates.py

diff --git a/examples/llama-3/instruct-dpo-lora-8b.yml b/examples/llama-3/instruct-dpo-lora-8b.yml
new file mode 100644
index 000000000..14febb810
--- /dev/null
+++ b/examples/llama-3/instruct-dpo-lora-8b.yml
@@ -0,0 +1,81 @@
+base_model: meta-llama/Meta-Llama-3-8B-Instruct
+model_type: LlamaForCausalLM
+tokenizer_type: AutoTokenizer
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+chat_template: llama3
+rl: dpo
+datasets:
+  - path: fozziethebeat/alpaca_messages_2k_dpo_test
+    type: chat_template.default
+    chat_template: llama3
+    field_messages: conversation
+    field_chosen: chosen
+    field_rejected: rejected
+    message_field_role: role
+    message_field_content: content
+    roles:
+      system:
+        - system
+      user:
+        - user
+      assistant:
+        - assistant
+
+dataset_prepared_path:
+val_set_size: 0.05
+output_dir: ./outputs/lora-out
+
+sequence_len: 4096
+sample_packing: false
+pad_to_sequence_len: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+s2_attention:
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
diff --git a/src/axolotl/prompt_strategies/dpo/chat_template.py b/src/axolotl/prompt_strategies/dpo/chat_template.py
new file mode 100644
index 000000000..33ce341c0
--- /dev/null
+++ b/src/axolotl/prompt_strategies/dpo/chat_template.py
@@ -0,0 +1,74 @@
+from axolotl.utils.chat_templates import chat_templates
+
+
+def default(
+    cfg, dataset_idx=0, **kwargs
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    ds_cfg = cfg["datasets"][dataset_idx]
+    chat_template_str = chat_templates(cfg.chat_template)
+
+    field_messages = ds_cfg.get("field_messages", "messages")
+    field_chosen = ds_cfg.get("field_chosen", "chosen")
+    field_rejected = ds_cfg.get("field_rejected", "rejected")
+    field_message_role = ds_cfg.get("message_field_role", "role")
+    field_message_content = ds_cfg.get("message_field_content", "content")
+    role_map_inv = ds_cfg.get(
+        "roles",
+        {
+            "user": ["user"],
+            "assistant": ["assistant"],
+            "system": ["system"],
+        },
+    )
+    role_map = {}
+    for target, sources in role_map_inv.items():
+        for source in sources:
+            role_map[source] = target
+
+    def transform_fn(sample, tokenizer=None):
+        messages = sample[field_messages]
+        messages = [
+            {
+                "role": m[field_message_role],
+                "content": m[field_message_content],
+            }
+            for m in messages
+        ]
+        chosen = {
+            "role": sample[field_chosen][field_message_role],
+            "content": sample[field_chosen][field_message_content],
+        }
+        rejected = {
+            "role": sample[field_rejected][field_message_role],
+            "content": sample[field_rejected][field_message_content],
+        }
+
+        result = {}
+        result["prompt"] = tokenizer.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            chat_template=chat_template_str,
+            tokenize=False,
+        )
+
+        result["chosen"] = tokenizer.apply_chat_template(
+            [chosen],
+            add_generation_prompt=False,
+            chat_template=chat_template_str,
+            tokenize=False,
+        )
+        chosen_strip_index = result["chosen"].find(chosen["content"])
+        result["chosen"] = result["chosen"][chosen_strip_index:]
+
+        result["rejected"] = tokenizer.apply_chat_template(
+            [rejected],
+            add_generation_prompt=False,
+            chat_template=chat_template_str,
+            tokenize=False,
+        )
+        rejected_strip_index = result["rejected"].find(rejected["content"])
+        result["rejected"] = result["rejected"][rejected_strip_index:]
+
+        return result
+
+    return transform_fn
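For orientation, a minimal sketch of what the strategy above produces for one sample. The config shape, sample shape, tokenizer checkpoint, and eos override all mirror the unit tests added later in this patch; none of it is new API.

    from transformers import AutoTokenizer

    from axolotl.prompt_strategies.dpo.chat_template import default
    from axolotl.utils.dict import DictDefault

    # Minimal config: top-level chat_template plus one dataset entry.
    cfg = DictDefault(
        {
            "chat_template": "llama3",
            "datasets": [{"chat_template": "llama3"}],
        }
    )
    transform_fn = default(cfg)

    # Same checkpoint and eos override as the test fixture below.
    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
    tokenizer.eos_token = "<|eot_id|>"

    sample = {
        "messages": [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hello"},
            {"role": "user", "content": "goodbye"},
        ],
        "chosen": {"role": "assistant", "content": "goodbye"},
        "rejected": {"role": "assistant", "content": "party on"},
    }

    result = transform_fn(sample, tokenizer=tokenizer)
    # result["prompt"] ends with the assistant generation header, while
    # result["chosen"] / result["rejected"] carry no prompt prefix:
    # "goodbye<|eot_id|>" and "party on<|eot_id|>" respectively.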
diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py
index 7416ca28b..d0324e1eb 100644
--- a/src/axolotl/utils/data/rl.py
+++ b/src/axolotl/utils/data/rl.py
@@ -1,4 +1,5 @@
 """data handling specific to DPO"""
+
 import inspect
 import logging
 from functools import partial
diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py
index 845296b7a..f353aebec 100644
--- a/src/axolotl/utils/tokenization.py
+++ b/src/axolotl/utils/tokenization.py
@@ -62,7 +62,7 @@ def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
     """Helper function to process and color tokens."""
     colored_tokens = [
         color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
-        for token in tokenizer.encode(tokens)
+        for token in tokenizer.encode(tokens, add_special_tokens=False)
     ]
     return colored_tokens
 
diff --git a/tests/prompt_strategies/test_dpo_chat_templates.py b/tests/prompt_strategies/test_dpo_chat_templates.py
new file mode 100644
index 000000000..d267b0d0c
--- /dev/null
+++ b/tests/prompt_strategies/test_dpo_chat_templates.py
@@ -0,0 +1,89 @@
+"""
+tests for chat_template prompt strategy
+"""
+
+import unittest
+
+import pytest
+from datasets import Dataset
+from transformers import AutoTokenizer
+
+from axolotl.prompt_strategies.dpo.chat_template import default
+from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.dict import DictDefault
+
+
+@pytest.fixture(name="assistant_dataset")
+def fixture_assistant_dataset():
+    # pylint: disable=duplicate-code
+    return Dataset.from_list(
+        [
+            {
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": "hello",
+                    },
+                    {
+                        "role": "assistant",
+                        "content": "hello",
+                    },
+                    {
+                        "role": "user",
+                        "content": "goodbye",
+                    },
+                ],
+                "chosen": {
+                    "role": "assistant",
+                    "content": "goodbye",
+                },
+                "rejected": {
+                    "role": "assistant",
+                    "content": "party on",
+                },
+            }
+        ]
+    )
+
+
+@pytest.fixture(name="llama3_tokenizer")
+def fixture_llama3_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
+    tokenizer.eos_token = "<|eot_id|>"
+
+    return tokenizer
+
+
+class TestAssistantChatTemplateLlama3:
+    """
+    Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
+    """
+
+    def test_llama3(self, llama3_tokenizer, assistant_dataset):
+        # pylint: disable=duplicate-code
+        transform_fn = default(
+            DictDefault(
+                {
+                    "chat_template": "llama3",
+                    "datasets": [
+                        {
+                            "chat_template": "llama3",
+                        }
+                    ],
+                }
+            )
+        )
+        result = transform_fn(assistant_dataset[0], tokenizer=llama3_tokenizer)
+        assert result["prompt"] == (
+            "<|begin_of_text|>"
+            + "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
+            + "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
+            + "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>"
+            + "<|start_header_id|>assistant<|end_header_id|>\n\n"
+        )
+        assert result["chosen"] == "goodbye<|eot_id|>"
+        assert result["rejected"] == "party on<|eot_id|>"
+
+
+if __name__ == "__main__":
+    unittest.main()
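The tokenization.py change in this patch exists because the DPO debug helper re-encodes completion strings that have already been rendered by the chat template. A short sketch of the difference; the checkpoint matches the test fixtures, and whether a BOS token is actually prepended depends on the tokenizer's configuration.

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")

    text = "goodbye<|eot_id|>"

    # Default encode() may prepend special tokens such as <|begin_of_text|>.
    with_special = tokenizer.encode(text)

    # The patched helper disables that, so the colored debug output shows
    # only the tokens that belong to the completion string itself.
    without_special = tokenizer.encode(text, add_special_tokens=False)

    print(with_special)     # possibly [BOS, ...tokens...]
    print(without_special)  # [...tokens...] only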
"<|eot_id|>" + + return tokenizer + + +class TestAssistantChatTemplateLlama3: + """ + Test class for assistant style datasets with llama-3 prompts using the chat_template strategy. + """ + + def test_llama3(self, llama3_tokenizer, assistant_dataset): + # pylint: disable=duplicate-code + transform_fn = default( + DictDefault( + { + "chat_template": "llama3", + "datasets": [ + { + "chat_template": "llama3", + } + ], + } + ) + ) + result = transform_fn(assistant_dataset[0], tokenizer=llama3_tokenizer) + assert result["prompt"] == ( + "<|begin_of_text|>" + + "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>" + + "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>" + + "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>" + + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + assert result["chosen"] == "goodbye<|eot_id|>" + assert result["rejected"] == "party on<|eot_id|>" + + +if __name__ == "__main__": + unittest.main() From 61000d543fb133af2e49e12044a8d903223359cb Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Tue, 2 Jul 2024 19:37:48 +0000 Subject: [PATCH 2/3] Adding additional dpo chat template unittests --- .../prompt_strategies/dpo/chat_template.py | 10 ++- .../test_dpo_chat_templates.py | 71 ++++++++++++++++++- 2 files changed, 76 insertions(+), 5 deletions(-) diff --git a/src/axolotl/prompt_strategies/dpo/chat_template.py b/src/axolotl/prompt_strategies/dpo/chat_template.py index 33ce341c0..4f2f14098 100644 --- a/src/axolotl/prompt_strategies/dpo/chat_template.py +++ b/src/axolotl/prompt_strategies/dpo/chat_template.py @@ -1,3 +1,7 @@ +""" +DPO prompt strategies for using tokenizer chat templates. +""" + from axolotl.utils.chat_templates import chat_templates @@ -29,17 +33,17 @@ def transform_fn(sample, tokenizer=None): messages = sample[field_messages] messages = [ { - "role": m[field_message_role], + "role": role_map[m[field_message_role]], "content": m[field_message_content], } for m in messages ] chosen = { - "role": sample[field_chosen][field_message_role], + "role": role_map[sample[field_chosen][field_message_role]], "content": sample[field_chosen][field_message_content], } rejected = { - "role": sample[field_rejected][field_message_role], + "role": role_map[sample[field_rejected][field_message_role]], "content": sample[field_rejected][field_message_content], } diff --git a/tests/prompt_strategies/test_dpo_chat_templates.py b/tests/prompt_strategies/test_dpo_chat_templates.py index d267b0d0c..da8719fd0 100644 --- a/tests/prompt_strategies/test_dpo_chat_templates.py +++ b/tests/prompt_strategies/test_dpo_chat_templates.py @@ -9,7 +9,6 @@ from transformers import AutoTokenizer from axolotl.prompt_strategies.dpo.chat_template import default -from axolotl.utils.chat_templates import chat_templates from axolotl.utils.dict import DictDefault @@ -46,6 +45,39 @@ def fixture_assistant_dataset(): ) +@pytest.fixture(name="custom_assistant_dataset") +def fixture_custom_assistant_dataset(): + # pylint: disable=duplicate-code + return Dataset.from_list( + [ + { + "conversation": [ + { + "speaker": "human", + "text": "hello", + }, + { + "speaker": "agent", + "text": "hello", + }, + { + "speaker": "human", + "text": "goodbye", + }, + ], + "better": { + "speaker": "agent", + "text": "goodbye", + }, + "worse": { + "speaker": "agent", + "text": "party on", + }, + } + ] + ) + + @pytest.fixture(name="llama3_tokenizer") def fixture_llama3_tokenizer(): tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B") @@ -59,7 +91,7 @@ class 
From 6654826784dcc02382f0180554a255768caec143 Mon Sep 17 00:00:00 2001
From: Keith Stevens
Date: Fri, 5 Jul 2024 15:53:11 +0000
Subject: [PATCH 3/3] Rename test class

---
 tests/prompt_strategies/test_dpo_chat_templates.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/prompt_strategies/test_dpo_chat_templates.py b/tests/prompt_strategies/test_dpo_chat_templates.py
index da8719fd0..cca48b1cf 100644
--- a/tests/prompt_strategies/test_dpo_chat_templates.py
+++ b/tests/prompt_strategies/test_dpo_chat_templates.py
@@ -86,7 +86,7 @@ def fixture_llama3_tokenizer():
     return tokenizer
 
 
-class TestAssistantChatTemplateLlama3:
+class TestAssistantDPOChatTemplateLlama3:
    """
    Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
    """
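End to end, the configured-fields scenario from patch 2 can be reproduced outside pytest. A sketch only; the checkpoint and eos override come from the fixtures above, and the row shape mirrors custom_assistant_dataset.

    from transformers import AutoTokenizer

    from axolotl.prompt_strategies.dpo.chat_template import default
    from axolotl.utils.dict import DictDefault

    # Dataset entry with renamed columns and non-canonical role labels.
    cfg = DictDefault(
        {
            "chat_template": "llama3",
            "datasets": [
                {
                    "chat_template": "llama3",
                    "field_messages": "conversation",
                    "field_chosen": "better",
                    "field_rejected": "worse",
                    "message_field_role": "speaker",
                    "message_field_content": "text",
                    "roles": {
                        "user": ["human"],
                        "assistant": ["agent"],
                        "system": ["sys"],
                    },
                }
            ],
        }
    )
    transform_fn = default(cfg)

    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
    tokenizer.eos_token = "<|eot_id|>"

    row = {
        "conversation": [
            {"speaker": "human", "text": "hello"},
            {"speaker": "agent", "text": "hello"},
            {"speaker": "human", "text": "goodbye"},
        ],
        "better": {"speaker": "agent", "text": "goodbye"},
        "worse": {"speaker": "agent", "text": "party on"},
    }

    result = transform_fn(row, tokenizer=tokenizer)
    # Completions come back with no prompt prefix, ready for DPO training:
    assert result["chosen"] == "goodbye<|eot_id|>"
    assert result["rejected"] == "party on<|eot_id|>"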