diff --git a/tests/kfto/core/config_qlora.json b/tests/kfto/core/config_qlora.json new file mode 100644 index 00000000..bb18eb29 --- /dev/null +++ b/tests/kfto/core/config_qlora.json @@ -0,0 +1,26 @@ +{ + "model_name_or_path": "/tmp/model/bloom-560m", + "training_data_path": "/etc/config/twitter_complaints_small.json", + "output_dir": "/tmp/out", + "save_model_dir": "/tmp/out", + "num_train_epochs": 1.0, + "per_device_train_batch_size": 4, + "per_device_eval_batch_size": 4, + "gradient_accumulation_steps": 4, + "save_strategy": "no", + "learning_rate": 1e-4, + "weight_decay": 0.0, + "lr_scheduler_type": "cosine", + "include_tokens_per_second": true, + "response_template": "\n### Label:", + "dataset_text_field": "output", + "use_flash_attn": false, + "peft_method": "qlora", + "target_modules": ["all-linear"], + "use_4bit": true, + "bnb_4bit_compute_dtype": "float16", + "bnb_4bit_quant_type": "nf4", + "use_nested_quant": false, + "fp16": false, + "bf16": false +} \ No newline at end of file diff --git a/tests/kfto/core/kfto_kueue_sft_test.go b/tests/kfto/core/kfto_kueue_sft_test.go index 6d9f0500..9897290a 100644 --- a/tests/kfto/core/kfto_kueue_sft_test.go +++ b/tests/kfto/core/kfto_kueue_sft_test.go @@ -37,6 +37,9 @@ func TestPytorchjobWithSFTtrainerFinetuning(t *testing.T) { func TestPytorchjobWithSFTtrainerLoRa(t *testing.T) { runPytorchjobWithSFTtrainer(t, "config_lora.json") } +func TestPytorchjobWithSFTtrainerQLoRa(t *testing.T) { + runPytorchjobWithSFTtrainer(t, "config_qlora.json") +} func runPytorchjobWithSFTtrainer(t *testing.T, modelConfigFile string) { test := With(t)