pad token
danielhanchen committed Jul 3, 2024
1 parent 018d632 · commit f8213cf
Showing 2 changed files with 11 additions and 12 deletions.
19 changes: 9 additions & 10 deletions unsloth/models/_utils.py
@@ -26,6 +26,15 @@
 import logging
 logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL+1)
 
+# Remove generation "Setting `pad_token_id` to `eos_token_id`:..."
+def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int:
+    if pad_token_id is None and eos_token_id is not None:
+        # logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+        pad_token_id = eos_token_id
+    return pad_token_id
+from transformers import GenerationMixin
+GenerationMixin._get_pad_token_id = _get_pad_token_id
+
 import bitsandbytes as bnb
 from transformers.models.llama.modeling_llama import logger
 from transformers import AutoTokenizer
@@ -101,16 +110,6 @@ def is_big_gpu(index):
 torch._inductor.utils.is_big_gpu = is_big_gpu
 
 
-# Remove generation "Setting `pad_token_id` to `eos_token_id`:..."
-def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int:
-    if pad_token_id is None and eos_token_id is not None:
-        # logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
-        pad_token_id = eos_token_id
-    return pad_token_id
-from transformers import GenerationMixin
-GenerationMixin._get_pad_token_id = _get_pad_token_id
-
-
 # Torch compile arguments
 torch_compile_arguments = [
     "config.dce = True",
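
Note on the change above: the commit moves the existing GenerationMixin monkeypatch earlier in _utils.py, just after the logging setup, so it is applied as soon as the module is imported; the patch itself makes generate() fall back to eos_token_id silently instead of logging "Setting `pad_token_id` to `eos_token_id`:... for open-end generation." A hypothetical usage sketch of the effect (not part of this commit; the model name and generation arguments are illustrative):

    # Hypothetical usage sketch (not part of this commit): with the patch above
    # applied at import time, a generate() call that omits pad_token_id falls
    # back to eos_token_id silently instead of printing the
    # "Setting `pad_token_id` to `eos_token_id`:..." notice.
    from unsloth import FastLanguageModel  # importing unsloth loads _utils and applies the patch

    # Model name and arguments are illustrative, not taken from this commit.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "unsloth/llama-3-8b-bnb-4bit",
        max_seq_length = 2048,
        load_in_4bit = True,
    )
    FastLanguageModel.for_inference(model)

    inputs = tokenizer("Hello, my name is", return_tensors = "pt").to("cuda")
    outputs = model.generate(**inputs, max_new_tokens = 16)  # no pad_token_id notice is logged
    print(tokenizer.batch_decode(outputs))
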
4 changes: 2 additions & 2 deletions unsloth/models/llama.py
@@ -15,6 +15,8 @@
 import torch
 import gc
 from typing import Optional, Tuple, List, Union
+from ._utils import *
+from ._utils import __version__
 from torch.nn.functional import scaled_dot_product_attention
 from transformers.models.llama.modeling_llama import (
     logger,
@@ -25,8 +27,6 @@
     _prepare_4d_causal_attention_mask_for_sdpa,
 )
 from ..kernels import *
-from ._utils import *
-from ._utils import __version__
 from ..tokenizer_utils import *
 if HAS_FLASH_ATTENTION:
     from flash_attn import flash_attn_func
