From f8213cf13f91164bafc241ea094ff28e8fe37a51 Mon Sep 17 00:00:00 2001
From: Daniel Han
Date: Tue, 2 Jul 2024 21:09:27 -0700
Subject: [PATCH] pad token

---
 unsloth/models/_utils.py | 19 +++++++++----------
 unsloth/models/llama.py  |  4 ++--
 2 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py
index bd512fc0f..90df68827 100644
--- a/unsloth/models/_utils.py
+++ b/unsloth/models/_utils.py
@@ -26,6 +26,15 @@
 import logging
 logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL+1)
 
+# Remove generation "Setting `pad_token_id` to `eos_token_id`:..."
+def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int:
+    if pad_token_id is None and eos_token_id is not None:
+        # logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+        pad_token_id = eos_token_id
+    return pad_token_id
+from transformers import GenerationMixin
+GenerationMixin._get_pad_token_id = _get_pad_token_id
+
 import bitsandbytes as bnb
 from transformers.models.llama.modeling_llama import logger
 from transformers import AutoTokenizer
@@ -101,16 +110,6 @@ def is_big_gpu(index):
 torch._inductor.utils.is_big_gpu = is_big_gpu
 
 
-# Remove generation "Setting `pad_token_id` to `eos_token_id`:..."
-def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int:
-    if pad_token_id is None and eos_token_id is not None:
-        # logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
-        pad_token_id = eos_token_id
-    return pad_token_id
-from transformers import GenerationMixin
-GenerationMixin._get_pad_token_id = _get_pad_token_id
-
-
 # Torch compile arguments
 torch_compile_arguments = [
     "config.dce = True",
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index adee029a8..94ef9ed41 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -15,6 +15,8 @@
 import torch
 import gc
 from typing import Optional, Tuple, List, Union
+from ._utils import *
+from ._utils import __version__
 from torch.nn.functional import scaled_dot_product_attention
 from transformers.models.llama.modeling_llama import (
     logger,
@@ -25,8 +27,6 @@
     _prepare_4d_causal_attention_mask_for_sdpa,
 )
 from ..kernels import *
-from ._utils import *
-from ._utils import __version__
 from ..tokenizer_utils import *
 if HAS_FLASH_ATTENTION:
     from flash_attn import flash_attn_func
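
Note (not part of the patch): the override moved above simply falls back from a missing pad_token_id to eos_token_id without logging. A minimal user-side sketch of the same behavior, assuming an ordinary Hugging Face causal LM checkpoint ("sshleifer/tiny-gpt2" is only an illustrative choice, not something the patch uses), is to pass the pad id explicitly so the notice never fires:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
    model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

    inputs = tokenizer("Hello", return_tensors="pt")
    # Supplying pad_token_id up front mirrors the patched fallback:
    # the pad id becomes the eos id, and no
    # "Setting `pad_token_id` to `eos_token_id`" message is logged.
    outputs = model.generate(**inputs, max_new_tokens=8, pad_token_id=tokenizer.eos_token_id)
    print(tokenizer.decode(outputs[0]))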