Skip to content

Commit 3a6573f

Browse files
committed
Update gemma2.py
1 parent 7c5d0ef commit 3a6573f

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

unsloth/models/gemma2.py

+6 -2
Original file line number | Diff line number | Diff line change
@@ -71,12 +71,14 @@ def Gemma2DecoderLayer_fast_forward(
7171
use_cache=use_cache,
7272
padding_mask=padding_mask,
7373
)
74+
hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
7475
hidden_states += residual
7576

7677
# Fully Connected
7778
residual = hidden_states
78-
hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
79+
hidden_states = fast_rms_layernorm_inference_gemma(self. pre_feedforward_layernorm, hidden_states, out_weight)
7980
hidden_states = fast_geglu_inference(self.mlp, hidden_states)
81+
hidden_states = fast_rms_layernorm_inference_gemma(self.post_feedforward_layernorm, hidden_states, out_weight)
8082
hidden_states += residual
8183
else:
8284
residual = hidden_states
@@ -151,11 +153,13 @@ def Gemma2Model_fast_forward_inference(
151153
attention_mask = attention_mask,
152154
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
153155
)
156+
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
154157
hidden_states += residual
155158

156159
residual = hidden_states
157-
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
160+
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer. pre_feedforward_layernorm, hidden_states, out_weight)
158161
hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
162+
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_feedforward_layernorm, hidden_states, out_weight)
159163
hidden_states += residual
160164

161165
next_decoder_cache.append(present_key_value)

0 commit comments

Comments
 (0)