diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py
index cf0aedab2..78e43746b 100644
--- a/unsloth/models/gemma2.py
+++ b/unsloth/models/gemma2.py
@@ -13,13 +13,7 @@
 # limitations under the License.

 from .llama import *
-from .gemma import (
-    GemmaFixedRotaryEmbedding,
-    fast_geglu_inference,
-    fast_rms_layernorm_inference_gemma,
-)
 from ._utils import __version__
-from ._utils import torch_compile_options

 from transformers.models.gemma2.modeling_gemma2 import (
     Gemma2Attention,
@@ -45,255 +39,22 @@
 pass


-from math import sqrt as math_sqrt
-KV_CACHE_INCREMENT = 256 # KV Cache update size
-torch_nn_functional_softmax = torch.nn.functional.softmax
-
-def Gemma2Attention_fast_forward_inference(
-    self,
-    hidden_states: torch.Tensor,
-    past_key_value: Optional[Tuple[torch.Tensor]],
-    position_ids,
-    do_prefill = False,
-    attention_mask = None,
-    sliding_window = None,
-):
-    """
-    https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
-    Fast inference using KV cache.
-    QK^T can be computed in 4 chunks
-
-    [Q, q] @ [K, k].T where q, k are the new tokens.
-    [QK^T, Qk^T]
-    [qK^T, qk^T]
-
-    Since the attention mask wipes Qk^T, we just get
-    [QK^T,    0]
-    [qK^T, qk^T]
-
-    Since softmax is row-wise, we get
-    softmax([QK^T,    0])
-    softmax([qK^T, qk^T])
-
-    We then multiply by   [V]
-                          [v]
-    softmax([QK^T,    0])  [softmax(QK^T)V] *
-    softmax([qK^T, qk^T])  [softmax([qK^T, qk^T]) @ [V, v]]
-
-    But notice * [softmax(QK^T)V] is just the last attention.
-    We just need to compute the last final row.
-
-    This means we can pass in a row of Q, but we need to
-    remember K and V, which are called the KV cache.
-    """
-    Xn = hidden_states
-    bsz, _, hd = hidden_states.size()
-    K1, V1 = past_key_value
-    dtype = Xn.dtype
-
-    n_heads = self.num_heads
-    n_groups = self.num_key_value_groups
-    n_kv_heads = self.num_key_value_heads
-    head_dim = self.head_dim
-    attention_size = n_heads*head_dim
-    # assert(n_kv_heads * n_groups == n_heads)
-    seq_len = K1.shape[-2]
-    kv_seq_len = seq_len + 1
-
-    # Prefill phase
-    # if not hasattr(self, "paged_attention"):
-    if do_prefill:
-        self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0")
-        self.paged_attention_K = self.paged_attention[:,0]
-        self.paged_attention_V = self.paged_attention[:,1]
-        self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
-        self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
-        self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
-        self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
-        self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
-        self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
-        # Gemma2 style scaling
-        self.scalar = 1.0 / math_sqrt(self.config.hidden_size // self.config.num_attention_heads)
-        self.half_head_dim = head_dim // 2
-    elif kv_seq_len >= self.paged_attention.shape[0]:
-        self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
-        self.paged_attention_K = self.paged_attention[:,0]
-        self.paged_attention_V = self.paged_attention[:,1]
-        self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
-    pass
-
-    Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
-    Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
-    Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
-    Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
-    Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
-    Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
-
-    # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
-    # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
-    cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1)
-    sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1)
-    h = self.half_head_dim
-
-    RH_Q = self.RH_Q
-    RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
-    RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
-    torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
-    Qn *= cos
-    Qn.addcmul_(RH_Q, sin)
-
-    RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
-    RH_K[:,:,:,:h] = Kn[:,:,:,h:]
-    RH_K[:,:,:,h:] = Kn[:,:,:,:h]
-    torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
-    Kn *= cos
-    Kn.addcmul_(RH_K, sin)
-
-    # New KV cache
-    # Kn = torch.cat([K1, Kn], dim = 2)
-    # Vn = torch.cat([V1, Vn], dim = 2)
-    self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)
-    self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)
-    Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)
-    Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)
-
-    # Handle sliding windows
-    if sliding_window is not None and kv_seq_len > sliding_window:
-        # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
-        slicing_tokens = 1 - sliding_window
-        Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
-        Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
-    else:
-        Knn, Vnn = Kn, Vn
-    pass
-
-    # Grouped query attention
-    _, _, cached_len, _ = Knn.shape
-    if n_groups != 1:
-        Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
-        Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
-        Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
-        Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
-    pass
-    # else:
-    #     Knn, Vnn = Knn, Vnn
-    # pass
-
-    # Attention
-    if bsz == 1:
-        Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
-        # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
-        A = torch.matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
-        # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
-        A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
-        A = torch.matmul(A, Vnn, out = Qn)
-    else:
-        A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
-    pass
-    A = A.transpose(1, 2)
-    A = A.reshape(bsz, 1, attention_size)
-    A = fast_linear_forward(self.o_proj, A, out = self.temp_QA[1][:,:,:self.hidden_size])
-    return A, (Kn, Vn)
-pass
-
-
-# SDPA but with logit softcapping (-50, 50)
-@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
-def gemma2_scaled_dot_product_attention(Q, K, V, mask, self, bsz, q_len):
-    n_heads = self.num_heads
-    head_dim = self.head_dim
-    n_kv_heads = self.num_key_value_heads
-    n_groups = self.num_key_value_groups
-
-    # Grouped query attention
-    K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
-    V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
-    K = K.reshape(bsz, n_heads, q_len, head_dim)
-    V = V.reshape(bsz, n_heads, q_len, head_dim)
-
-    # Gemma2 logit softcapping
-    # 50 * tanh(Q @ K.T / sqrt(s) / 50) bounds it to (-50, 50)
-    t = self.config.attn_logit_softcapping
-    s = self.config.hidden_size // self.config.num_attention_heads
-
-    # Must downcast like in Keras
-    Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype)
-    A = t * torch.tanh(torch.matmul(Q, K.transpose(2, 3)) / t)
-
-    # Add SWA or Global Attention matrix
-    A += mask[:q_len, :q_len]
-    A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32)
-    A = A.to(Q.dtype)
-    A = torch.matmul(A, V)
-    return A
-pass
-
-
-def Gemma2Attention_fast_forward(
-    self,
-    hidden_states: torch.Tensor,
-    causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-    padding_mask: Optional[torch.LongTensor] = None,
-    *args, **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
-    # Clear inference
-    if hasattr(self, "paged_attention"):
-        del self.paged_attention_K
-        del self.paged_attention_V
-        del self.paged_attention
-        del self.temp_QA
-        del self.temp_KV
-        del self.RH_Q
-        del self.attention
-    pass
-
-    bsz, q_len, _ = hidden_states.size()
-
-    n_heads = self.num_heads
-    n_groups = self.num_key_value_groups
-    n_kv_heads = self.num_key_value_heads
-    head_dim = self.head_dim
-    assert(n_kv_heads * n_groups == n_heads)
-
-    Q, K, V = self.apply_qkv(self, hidden_states)
-    Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
-    K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
-    V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
-
-    kv_seq_len = K.shape[-2]
-    if past_key_value is not None:
-        kv_seq_len += past_key_value[0].shape[-2]
-
-    if position_ids is None:
-        cos = self.rotary_emb.cos_cached
-        sin = self.rotary_emb.sin_cached
-        Q, K = fast_rope_embedding(Q, K, cos, sin)
-    else:
-        cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
-        Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
-    pass
-
-    if past_key_value is not None:
-        K = torch.cat([past_key_value[0], K], dim = 2)
-        V = torch.cat([past_key_value[1], V], dim = 2)
-    pass
-    past_key_value = (K, V) if use_cache else None
+torch_nn_functional_gelu = torch.nn.functional.gelu
+def fast_geglu_inference(self, X):
+    # gate = self.gate_proj(X)
+    # up = self.up_proj(X)
+    bsz, _, hd = X.shape
+    # mlp_size = self.config.intermediate_size
+    # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")

-    # Attention module
-    A = gemma2_scaled_dot_product_attention(Q, K, V, causal_mask, self, bsz, q_len)
-    # Go back to (batch_size, seq_len, n_heads, head_dim)
-    A = A.transpose(1, 2)
+    gate = fast_linear_forward(self.gate_proj, X)#, out = temp[0])
+    up = fast_linear_forward(self.up_proj, X)#, out = temp[1])
+    gate = torch_nn_functional_gelu(gate, approximate = "tanh")
+    gate *= up

-    attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
-    attn_output = self.apply_o(self, attn_output)
-    attn_weights = None
-    return attn_output, attn_weights, past_key_value
+    # X = self.down_proj(gate)
+    down = fast_linear_forward(self.down_proj, gate, out = up[:,:,:hd])
+    return down
 pass


@@ -326,14 +87,12 @@ def Gemma2DecoderLayer_fast_forward(
             use_cache=use_cache,
             padding_mask=padding_mask,
         )
-        hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
         hidden_states += residual

         # Fully Connected
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference_gemma(self.pre_feedforward_layernorm, hidden_states, out_weight)
+        hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
         hidden_states = fast_geglu_inference(self.mlp, hidden_states)
-        hidden_states = fast_rms_layernorm_inference_gemma(self.post_feedforward_layernorm, hidden_states, out_weight)
         hidden_states += residual
     else:
         residual = hidden_states
@@ -348,14 +107,12 @@ def Gemma2DecoderLayer_fast_forward(
             use_cache=use_cache,
             padding_mask=padding_mask,
         )
-        hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True)
         hidden_states = residual + hidden_states

         # Fully Connected
         residual = hidden_states
-        hidden_states = fast_rms_layernorm(self.pre_feedforward_layernorm, hidden_states, gemma = True)
+        hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = fast_rms_layernorm(self.post_feedforward_layernorm, hidden_states, gemma = True)
         hidden_states = residual + hidden_states
     pass

@@ -377,7 +134,6 @@ def Gemma2Model_fast_forward_inference(
     position_ids,
     attention_mask = None,
 ):
-    sliding_window = self.config.sliding_window
     out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
     input_ids = input_ids[:,:self.max_seq_length]
     hidden_states = self.model.embed_tokens(input_ids)
@@ -389,19 +145,11 @@ def Gemma2Model_fast_forward_inference(
     bsz, q_len, hd = hidden_states.shape
     seq_len = past_key_values[0][0].shape[-2]
     if bsz != 1:
-        swa_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
             attention_mask,
             (bsz, q_len),
             hidden_states,
             seq_len,
-            sliding_window = getattr(self.config, "sliding_window", None),
-        )
-        global_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-            attention_mask,
-            (bsz, q_len),
-            hidden_states,
-            seq_len,
-            sliding_window = None,
         )
     pass

@@ -409,22 +157,19 @@ def Gemma2Model_fast_forward_inference(
     for idx, decoder_layer in enumerate(self.model.layers):
         residual = hidden_states
         hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight)
-        hidden_states, present_key_value = Gemma2Attention_fast_forward_inference(
+        hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
             decoder_layer.self_attn,
             hidden_states = hidden_states,
             past_key_value = past_key_values[idx],
             position_ids = position_ids,
             attention_mask = attention_mask,
             do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
-            sliding_window = swa_attention_mask if (idx % 2 == 0) else global_attention_mask
         )
-        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
         hidden_states += residual

         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.pre_feedforward_layernorm, hidden_states, out_weight)
+        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
         hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
-        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_feedforward_layernorm, hidden_states, out_weight)
         hidden_states += residual

         next_decoder_cache.append(present_key_value)
@@ -440,17 +185,68 @@
     pass


+# Follows line by line https://github.com/google-deepmind/gemma/blob/main/gemma/positional_embeddings.py#L45
+# Formulates cos and sin differently from Llama!
+class GemmaFixedRotaryEmbedding(torch.nn.Module):
+    # Fixes https://github.com/huggingface/transformers/pull/28837
+    # https://github.com/microsoft/DeepSpeed/issues/4932
+    # The precision of RoPE buffers is not correct, so we cast to int64.
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+        super().__init__()
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())
+    pass
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
+        # in FP32. They are applied (multiplied) in FP32 as well.
+        self.max_seq_len_cached = seq_len
+
+        # The difference is we do division explicity instead of t * (1/x) ie we do t/x.
+        freq_exponents = (2.0 / self.dim) * (
+            torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
+        )
+        timescale = self.base**freq_exponents
+        positions = torch.arange(self.max_seq_len_cached, device = "cpu", dtype = torch.int64).float()
+        radians_new = positions[..., None] / timescale[None, None, :]
+        radians_new = radians_new.squeeze(0)
+
+        emb = torch.cat((radians_new, radians_new), dim = -1)
+        # We must do RoPE in float32!
+        cos = emb.cos().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
+        sin = emb.sin().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
+        self.register_buffer("cos_cached", cos, persistent = False)
+        self.register_buffer("sin_cached", sin, persistent = False)
+    pass
+
+    def forward(self, x, position_ids=None, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
+        )
+    pass
+pass
+
+
 class FastGemma2Model(FastLlamaModel):

     @staticmethod
     def pre_patch():
-        Gemma2Attention      .forward = Gemma2Attention_fast_forward
-        Gemma2SdpaAttention  .forward = Gemma2Attention_fast_forward
-        Gemma2FlashAttention2.forward = Gemma2Attention_fast_forward
+        Gemma2Attention      .forward = LlamaAttention_fast_forward
+        Gemma2SdpaAttention  .forward = LlamaAttention_fast_forward
+        Gemma2FlashAttention2.forward = LlamaAttention_fast_forward
         Gemma2DecoderLayer   .forward = Gemma2DecoderLayer_fast_forward
         Gemma2Model          .forward = LlamaModel_fast_forward
         Gemma2ForCausalLM    .forward = CausalLM_fast_forward(Gemma2Model_fast_forward_inference)
-        PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
+        PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
         # Solves https://github.com/unslothai/unsloth/issues/168
         # Static KV Cache was introduced in 4.38.0, causing training to be much slower.
         # Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings.
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index a6946377c..fc0dfbb3d 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -644,22 +644,22 @@ def LlamaModel_fast_forward(
     pass

     # Gemma2 has alternating SWA and global attn
-    if IS_GEMMA2 and not hasattr(self, "SWA_mask"):
-        from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-        n = self.config.max_position_embeddings
-        self.SWA_mask = AttentionMaskConverter(
-            is_causal = True,
-            sliding_window = self.config.sliding_window,
-        )\
-            .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
-            .squeeze(0).squeeze(0)
-
-        self.GA_mask = AttentionMaskConverter(
-            is_causal = True,
-        )\
-            .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
-            .squeeze(0).squeeze(0)
-    pass
+    # if IS_GEMMA2 and not hasattr(self, "SWA_mask"):
+    #     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+    #     n = self.config.max_position_embeddings
+    #     self.SWA_mask = AttentionMaskConverter(
+    #         is_causal = True,
+    #         sliding_window = self.config.sliding_window,
+    #     )\
+    #         .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
+    #         .squeeze(0).squeeze(0)
+
+    #     self.GA_mask = AttentionMaskConverter(
+    #         is_causal = True,
+    #     )\
+    #         .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
+    #         .squeeze(0).squeeze(0)
+    # pass

     # Go through every layer!
     for idx, decoder_layer in enumerate(self.layers):
@@ -668,7 +668,7 @@ def LlamaModel_fast_forward(
         past_key_value = past_key_values[idx] if past_key_values is not None else None

         mask = causal_mask
-        if IS_GEMMA2: mask = self.SWA_mask if (idx % 2 == 0) else self.GA_mask
+        # if IS_GEMMA2: mask = self.SWA_mask if (idx % 2 == 0) else self.GA_mask

         if offloaded_gradient_checkpointing:
             hidden_states = Unsloth_Offloaded_Gradient_Checkpointer.apply(
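
For reference, the docstring of the removed `Gemma2Attention_fast_forward_inference` describes the usual KV-cache decode step: only the new query row `q @ [K, k]^T` is computed and softmaxed, then multiplied by the stacked values `[V; v]`. The sketch below illustrates that step in plain PyTorch under simplified assumptions (no RoPE, no sliding window, no grouped-query expansion; `decode_step` is a hypothetical helper, not the removed kernel):

```python
import torch

def decode_step(q, k_new, v_new, K_cache, V_cache):
    # q, k_new, v_new: (bsz, n_heads, 1, head_dim); caches: (bsz, n_heads, seq_len, head_dim)
    K = torch.cat([K_cache, k_new], dim = 2)
    V = torch.cat([V_cache, v_new], dim = 2)
    scale  = q.shape[-1] ** -0.5
    scores = torch.matmul(q * scale, K.transpose(2, 3))        # only the last attention row
    probs  = torch.nn.functional.softmax(scores, dim = -1, dtype = torch.float32).to(q.dtype)
    out    = torch.matmul(probs, V)                             # (bsz, n_heads, 1, head_dim)
    return out, (K, V)                                          # output plus the updated KV cache
```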
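The removed `gemma2_scaled_dot_product_attention` applied Gemma-2's logit softcapping, `t * tanh(QK^T / sqrt(s) / t)`, before the softmax so the attention logits stay in `(-t, t)`. A minimal sketch of that scoring rule, assuming already-expanded (non-GQA) `Q`, `K`, `V` and an additive mask; `softcapped_attention` is an illustrative name, not an Unsloth API:

```python
import torch

def softcapped_attention(Q, K, V, mask, softcap = 50.0):
    # Q, K, V: (bsz, n_heads, q_len, head_dim); mask: additive causal or sliding-window mask
    scale  = Q.shape[-1] ** -0.5
    scores = torch.matmul(Q * scale, K.transpose(2, 3))
    scores = softcap * torch.tanh(scores / softcap)   # bound logits to (-softcap, softcap)
    scores = scores + mask
    probs  = torch.nn.functional.softmax(scores, dim = -1, dtype = torch.float32)
    return torch.matmul(probs.to(Q.dtype), V)
```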
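The `GemmaFixedRotaryEmbedding` added above builds its cos/sin cache from the explicit division `positions / base**(2i/dim)`, computed in float32, mirroring the Google DeepMind reference formulation. A condensed sketch of the same cache construction (shapes flattened to 2-D; `gemma_rope_cache` is a hypothetical helper):

```python
import torch

def gemma_rope_cache(dim, seq_len, base = 10000.0):
    exponents = (2.0 / dim) * torch.arange(dim // 2, dtype = torch.int64).float()
    timescale = base ** exponents                          # (dim//2,)
    positions = torch.arange(seq_len, dtype = torch.int64).float()
    radians   = positions[:, None] / timescale[None, :]    # explicit t/x division, kept in float32
    emb = torch.cat((radians, radians), dim = -1)          # (seq_len, dim)
    return emb.cos(), emb.sin()

cos, sin = gemma_rope_cache(dim = 256, seq_len = 8192)
print(cos.shape, cos.dtype)   # torch.Size([8192, 256]) torch.float32
```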
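Finally, the commented-out block in `llama.py` prebuilt two additive masks: a sliding-window causal mask for even-indexed Gemma-2 layers and a fully causal ("global") one for odd-indexed layers. A small pure-PyTorch sketch of that alternation (mask construction simplified; `causal_mask` is an illustrative helper, not the `AttentionMaskConverter` call the removed code used):

```python
import torch

def causal_mask(n, sliding_window = None, dtype = torch.float32):
    # 0 where attention is allowed, finfo.min where it is masked out
    idx  = torch.arange(n)
    keep = idx[None, :] <= idx[:, None]                           # causal
    if sliding_window is not None:
        keep &= idx[None, :] > (idx[:, None] - sliding_window)    # only the last `sliding_window` positions
    return torch.zeros(n, n, dtype = dtype).masked_fill(~keep, torch.finfo(dtype).min)

SWA_mask = causal_mask(8, sliding_window = 4)
GA_mask  = causal_mask(8)
for idx in range(4):
    mask = SWA_mask if (idx % 2 == 0) else GA_mask   # same even/odd alternation as the removed code
```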