Skip to content

Commit

Permalink
Update gemma2.py
Browse files: browse the repository at this point in the history
  • Loading branch information
danielhanchen committed Jun 30, 2024
1 parent f2777ed commit 73148cd
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions unsloth/models/gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,7 @@ def Gemma2DecoderLayer_forward(
  attention_mask = attention_mask[:, :, :, -self.sliding_window :]

  residual = hidden_states
-
- hidden_states = self.input_layernorm(hidden_states)
-
- # Self Attention
+ hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True)
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
  hidden_states=hidden_states,
  attention_mask=attention_mask,
Expand All @@ -292,13 +289,13 @@ def Gemma2DecoderLayer_forward(
  use_cache=use_cache,
  cache_position=cache_position,
  )
- hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True)
  hidden_states = residual + hidden_states

  residual = hidden_states
- hidden_states = self.pre_feedforward_layernorm(hidden_states)
+ hidden_states = fast_rms_layernorm(self. pre_feedforward_layernorm, hidden_states, gemma = True)
  hidden_states = self.mlp(hidden_states)
- hidden_states = self.post_feedforward_layernorm(hidden_states)
+ hidden_states = fast_rms_layernorm(self.post_feedforward_layernorm, hidden_states, gemma = True)
  hidden_states = residual + hidden_states

  outputs = (hidden_states,)
Expand Down

0 comments on commit 73148cd

Please sign in to comment.