Update gemma2.py
danielhanchen committed Jul 3, 2024
1 parent d4616d8 · commit 58f016a
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions unsloth/models/gemma2.py
@@ -258,6 +258,8 @@ def Gemma2Attention_fast_forward_inference(
         self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
         self.scalar = 1.0 / math_sqrt(self.config.hidden_size // self.config.num_attention_heads)
         self.half_head_dim = head_dim // 2
+        self.t = self.config.attn_logit_softcapping
+        self.reciprocal_t = 1.0 / self.config.attn_logit_softcapping
     elif kv_seq_len >= self.paged_attention.shape[0]:
         self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
         self.paged_attention_K = self.paged_attention[:,0]
@@ -329,9 +331,8 @@ def Gemma2Attention_fast_forward_inference(
     # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
     A = torch.matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
     # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
-
-    t = self.config.attn_logit_softcapping
-    A *= (1.0 / t); torch.tanh(A, out = A); A *= t; # Logit softcapping
+
+    A *= self.reciprocal_t; torch.tanh(A, out = A); A *= self.t; # Logit softcapping
 
     A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
     A = torch.matmul(A, Vnn, out = Qn)
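For context: the removed lines re-read attn_logit_softcapping and recomputed its reciprocal on every decode step, while the new code reuses the self.t and self.reciprocal_t constants cached during prefill, so a per-step division becomes a single multiply. The softcapping itself is softcap(x) = t * tanh(x / t), applied in place to the attention scores before the softmax. Below is a minimal standalone sketch of that step (not unsloth's API), assuming an illustrative Gemma-2-style softcapping value of 50.0 and toy tensor shapes:

import torch

# Sketch only: cache the softcapping constants once, as the commit does
# with self.t and self.reciprocal_t.
t = 50.0               # assumed attn_logit_softcapping value
reciprocal_t = 1.0 / t

bsz, n_heads, head_dim, kv_seq_len = 1, 8, 64, 32
Qn  = torch.randn(bsz, n_heads, 1, head_dim) * (1.0 / head_dim ** 0.5)  # query, pre-scaled
Knn = torch.randn(bsz, n_heads, kv_seq_len, head_dim)                   # cached keys
Vnn = torch.randn(bsz, n_heads, kv_seq_len, head_dim)                   # cached values

# Attention scores; scaling Q before the matmul keeps intermediates smaller
A = torch.matmul(Qn, Knn.transpose(2, 3))

# Logit softcapping, in place: softcap(x) = t * tanh(x / t)
A *= reciprocal_t
torch.tanh(A, out = A)
A *= t

A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32)
out = torch.matmul(A, Vnn)   # (bsz, n_heads, 1, head_dim)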
