Label Leakage in Gemma 2 Finetuning #31673

Closed
hiyouga opened this issue Jun 27, 2024 · 1 comment · Fixed by #31674
@hiyouga (Contributor) commented Jun 27, 2024

System Info

  • transformers version: 4.42.1
  • Platform: Linux-5.4.0-147-generic-x86_64-with-glibc2.31
  • Python version: 3.10.13
  • Huggingface_hub version: 0.23.4
  • Safetensors version: 0.4.3
  • Accelerate version: 0.30.1
  • PyTorch version (GPU?): 2.3.0+cu121 (True)

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Loss curves

We initially found that the training loss quickly dropped to 0 when fine-tuning the gemma2 models.

[Figure: training loss curve, dropping rapidly to 0]

This behavior strongly suggests label leakage: if each position can attend to the token it is supposed to predict, the model can simply copy it, and the training loss collapses to zero.

Breakpoints

We changed the following lines of Gemma2DecoderLayer.forward, shown first as released and then with print statements added around the sliding-window mask computation.

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
            attention_mask = attention_mask * torch.tril(
                torch.ones_like(attention_mask), diagonal=-self.sliding_window
            )
            if attention_mask.shape[1] <= 1:  # when decoding
                attention_mask = attention_mask[:, -self.sliding_window :]

        if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
            print("before", attention_mask)
            attention_mask = attention_mask * torch.tril(
                torch.ones_like(attention_mask), diagonal=-self.sliding_window
            )
            print("after", attention_mask)
            if attention_mask.shape[1] <= 1:  # when decoding
                attention_mask = attention_mask[:, -self.sliding_window :]
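
For reference, the effect of this multiplication can be reproduced without loading the model at all. The sketch below is not part of transformers; the values are chosen to mirror the real case (Gemma 2's sliding window of 4096 and a 3-token sequence). It shows that torch.tril(torch.ones_like(mask), diagonal=-sliding_window) is all zeros whenever the sequence is shorter than the window, so the product erases the causal entries of the additive mask:

import torch

sliding_window = 4096  # Gemma 2's sliding window size
seq_len = 3            # any length <= sliding_window shows the same effect
min_value = torch.finfo(torch.bfloat16).min

# Additive causal mask: 0 where attention is allowed, the dtype minimum where it is not.
causal_mask = torch.full((1, 1, seq_len, seq_len), min_value, dtype=torch.bfloat16).triu(diagonal=1)
print("before", causal_mask)

# tril(..., diagonal=-sliding_window) keeps only positions more than sliding_window
# steps in the past; for a short sequence that is nothing at all, so the
# multiplication wipes out the causal entries as well.
sliding_mask = torch.tril(torch.ones_like(causal_mask), diagonal=-sliding_window)
print("after", causal_mask * sliding_mask)  # all zeros -> unrestricted attention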

Then we execute the following code to trigger these prints.

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it", torch_dtype="auto", device_map="auto", attn_implementation="eager"
)
inputs = torch.tensor([[0, 1, 2]]).to(model.device)
labels = inputs.clone()
outputs = model(input_ids=inputs, labels=labels)

The result is:

before tensor([[[[ 0.0000e+00, -3.3895e+38, -3.3895e+38],
          [ 0.0000e+00,  0.0000e+00, -3.3895e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00]]]], device='cuda:0',
       dtype=torch.bfloat16)
after tensor([[[[0., -0., -0.],
          [0., 0., -0.],
          [0., 0., 0.]]]], device='cuda:0', dtype=torch.bfloat16)
before tensor([[[[ 0.0000e+00, -3.3895e+38, -3.3895e+38],
          [ 0.0000e+00,  0.0000e+00, -3.3895e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00]]]], device='cuda:0',
       dtype=torch.bfloat16)
after tensor([[[[0., -0., -0.],
          [0., 0., -0.],
          [0., 0., 0.]]]], device='cuda:0', dtype=torch.bfloat16)
...
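
The leakage can also be observed without editing the library source. As an additional check (not part of the original report), one can reuse the model and inputs defined above and request the attention probabilities, which eager attention exposes via output_attentions=True; any attention mass strictly above the diagonal means a token attends to a future position, which a causal mask should make impossible:

# Reuses `model` and `inputs` from the snippet above.
# outputs.attentions holds one (batch, heads, q_len, kv_len) tensor per layer.
outputs = model(input_ids=inputs, output_attentions=True)
for idx, attn in enumerate(outputs.attentions):
    future_mass = attn[0].triu(diagonal=1).sum().item()
    if future_mass > 0:  # expected only on the affected sliding-window layers
        print(f"layer {idx}: attention mass on future tokens = {future_mass:.4f}")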

Expected behavior

The attention mask should not be reduced to all zeros for sequences shorter than the sliding window. With an all-zero additive mask, the sliding-window layers apply full (bidirectional) attention instead of causal attention when computing the loss, so every position can see the label it is supposed to predict, which is label leakage.
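
For illustration only, here is a sketch of one possible repair (not necessarily the change made in #31674): apply the sliding-window restriction by masking out positions beyond the window with the dtype minimum, e.g. via torch.where, instead of multiplying, so the existing causal entries are preserved. The helper name apply_sliding_window is hypothetical:

import torch

def apply_sliding_window(attention_mask: torch.Tensor, sliding_window: int) -> torch.Tensor:
    # Hypothetical helper for illustration: positions more than sliding_window
    # steps in the past are set to the dtype minimum; everything else, including
    # the existing causal entries, is left untouched.
    min_value = torch.finfo(attention_mask.dtype).min
    outside_window = torch.tril(
        torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-sliding_window
    )
    return torch.where(outside_window, torch.full_like(attention_mask, min_value), attention_mask)

# With the 3-token causal mask from the reproduction above, the upper triangle
# keeps its large negative values instead of collapsing to zero.
causal_mask = torch.full((1, 1, 3, 3), torch.finfo(torch.bfloat16).min, dtype=torch.bfloat16).triu(diagonal=1)
print(apply_sliding_window(causal_mask, sliding_window=4096))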

@ArthurZucker (Collaborator) commented:

Great catch, and sorry that it affected you!
