Implement efficient packing without cross-contamination attention #4224

Open
chuan298 wants to merge 4 commits into base: main
Conversation

@chuan298 commented Jun 11, 2024

What does this PR do?

Update 15/6/2024: Added packing support for eager and sdpa attention.


Fixes #2289

Implement efficient packing without cross-contamination attention
Taking inspiration from repositories such as axolotl and functionary, I implemented sequence packing more efficiently, so that the model learns from each sample without attending to other samples within the same pack. For now, this implementation only supports SFT with flash_attention_2.
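
For eager and sdpa attention (added in the 15/6/2024 update), cross-contamination within a pack can be prevented with a block-diagonal causal mask, so that each token only attends to earlier tokens of its own sample. Below is a minimal sketch of that idea, assuming the packed attention mask stores per-sample indices; the function name and shapes are illustrative, not the PR's actual code.

```python
# Illustrative sketch: build a block-diagonal causal mask from a packed
# attention mask whose entries are per-sample indices (1, 2, 3, ...) with 0
# for padding, so samples in the same pack cannot attend to each other.
import torch

def block_diagonal_causal_mask(sample_ids: torch.Tensor) -> torch.Tensor:
    # sample_ids: (batch, seq_len), e.g. [[1, 1, 1, 2, 2, 3, 0]]
    seq_len = sample_ids.size(1)
    same_sample = sample_ids.unsqueeze(1) == sample_ids.unsqueeze(2)  # (b, s, s)
    not_padding = (sample_ids > 0).unsqueeze(1)                       # (b, 1, s)
    causal = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=sample_ids.device)
    )
    # a position may be attended to only if it belongs to the same sample,
    # is not padding, and does not lie in the future
    return (same_sample & not_padding & causal).unsqueeze(1)          # (b, 1, s, s)

ids = torch.tensor([[1, 1, 1, 2, 2, 3, 0]])
print(block_diagonal_causal_mask(ids)[0, 0].int())
```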

Example training config:

### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
flash_attn: fa2

### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all

### dataset
dataset: alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
efficient_packing: true

### output
output_dir: saves/llama3-8b/lora/sft
logging_steps: 1
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
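
Assuming the config above is saved to a YAML file, training can then be launched with the standard LLaMA-Factory CLI, e.g. llamafactory-cli train path/to/sft_packing.yaml (the file path is illustrative).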

Before submitting

@hiyouga added the "pending (This problem is yet to be addressed)" label on Jun 12, 2024
@hiyouga mentioned this pull request on Jun 15, 2024
@AlongWY (Contributor) commented Jun 20, 2024

Should we consider implementing this with varlen_flash_atten?

@@ -33,6 +33,9 @@ def run_sft(
    dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

    if data_args.efficient_packing:
        configure_packing(model.config, model_args)
@hiyouga (Owner) replied:

could we do configure_packing in llamafactory.model.patcher?

@chuan298 (Author) replied:

Sure, I just edited it.
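
For readers following along, a rough sketch of what moving the hook into llamafactory.model.patcher could look like; the patch_model signature and the body of configure_packing below are hypothetical placeholders, not the PR's actual code.

```python
# Hypothetical sketch: wire the packing hook through the model patcher
# rather than calling it directly inside run_sft().
def configure_packing(config, model_args) -> None:
    # Swap in packing-aware attention-mask handling for supported model classes
    # so that packed samples are treated as independent sequences.
    ...

def patch_model(model, tokenizer, model_args, data_args) -> None:
    if getattr(data_args, "efficient_packing", False):
        configure_packing(model.config, model_args)
```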

@@ -66,6 +66,21 @@

SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}

SUPPORTED_CLASS_FOR_MULTIPACK = [
@hiyouga (Owner) replied:

is it "efficient_packing" rather than "multipack"?

@chuan298 (Author) replied:

Yes, I just fixed it.

@chuan298 (Author) replied:

> Should we consider implementing this with varlen_flash_atten?

Hi @AlongWY, the models in transformers already use flash_attn_varlen_func by default when an attention_mask is passed. I just made a slight change to the attention_mask when packing sequences, and returned the indices, cu_seqlens, and max_seqlen_in_batch corresponding to the modified attention_mask.
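
To make that concrete, here is a minimal, hypothetical sketch of the conversion described above: a packed attention mask holding per-sample indices is turned into the indices, cu_seqlens, and max_seqlen_in_batch that flash_attn_varlen_func consumes. It is modeled on the shape of transformers' _get_unpad_data helper, but the code below is illustrative rather than the PR's actual implementation.

```python
# Illustrative sketch: derive flash_attn_varlen_func inputs from a packed
# attention mask that stores per-sample indices (1, 2, 3, ...) and 0 for padding.
import torch
import torch.nn.functional as F

def get_unpad_data_for_packing(attention_mask: torch.Tensor):
    # attention_mask: (batch, seq_len), e.g. [[1, 1, 1, 2, 2, 3, 0]]
    bsz = attention_mask.size(0)
    max_num = int(attention_mask.max())

    # count the tokens of each packed sample, row by row
    counts = torch.zeros((bsz, max_num), dtype=torch.int32, device=attention_mask.device)
    for i in range(max_num):
        counts[:, i] = (attention_mask == i + 1).sum(dim=-1)
    seqlens_in_batch = counts.flatten()
    seqlens_in_batch = seqlens_in_batch[seqlens_in_batch > 0]  # drop empty slots

    # flat positions of all non-padding tokens
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    # cumulative sequence lengths, prefixed with 0, as expected by flash-attn
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

mask = torch.tensor([[1, 1, 1, 2, 2, 3, 0]])
print(get_unpad_data_for_packing(mask))
# cu_seqlens -> tensor([0, 3, 5, 6]), max_seqlen_in_batch -> 3
```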

Successfully merging this pull request may close the following issue: sft_packing实现的问题 (problem with the sft_packing implementation)

3 participants