@@ -71,12 +71,14 @@ def Gemma2DecoderLayer_fast_forward(
             use_cache = use_cache,
             padding_mask = padding_mask,
         )
+        hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
         hidden_states += residual
 
         # Fully Connected
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
+        hidden_states = fast_rms_layernorm_inference_gemma(self.pre_feedforward_layernorm, hidden_states, out_weight)
         hidden_states = fast_geglu_inference(self.mlp, hidden_states)
+        hidden_states = fast_rms_layernorm_inference_gemma(self.post_feedforward_layernorm, hidden_states, out_weight)
         hidden_states += residual
     else:
         residual = hidden_states
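For context, this hunk aligns the fast path with Gemma 2's sandwich normalization: `post_attention_layernorm` now normalizes the attention output *before* the residual add (in the Gemma 1 / Llama layout it served as the pre-MLP norm), and the GeGLU MLP is wrapped by `pre_feedforward_layernorm` and `post_feedforward_layernorm`. Below is a minimal sketch of the resulting block ordering in plain PyTorch, standing in for the fused `fast_rms_layernorm_inference_gemma` / `fast_geglu_inference` kernels; the names `GemmaRMSNorm` and `gemma2_block` are illustrative, and `self_attn` is simplified to return a single tensor:

```python
import torch
from torch import nn

class GemmaRMSNorm(nn.Module):
    # Gemma-style RMSNorm: zero-initialised weight, output scaled by (1 + weight).
    # This is the variant the fused fast_rms_layernorm_inference_gemma kernel computes.
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim = True) + self.eps)
        return (normed * (1.0 + self.weight.float())).type_as(x)

def gemma2_block(layer, hidden_states: torch.Tensor) -> torch.Tensor:
    # Attention: pre-norm -> attention -> post-attention norm -> residual add.
    residual = hidden_states
    hidden_states = layer.input_layernorm(hidden_states)
    hidden_states = layer.self_attn(hidden_states)
    hidden_states = layer.post_attention_layernorm(hidden_states)
    hidden_states = residual + hidden_states

    # MLP: pre-FFN norm -> GeGLU MLP -> post-FFN norm -> residual add.
    residual = hidden_states
    hidden_states = layer.pre_feedforward_layernorm(hidden_states)
    hidden_states = layer.mlp(hidden_states)
    hidden_states = layer.post_feedforward_layernorm(hidden_states)
    return residual + hidden_states
```

The same reordering is mirrored in the batched inference path below.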
@@ -151,11 +153,13 @@ def Gemma2Model_fast_forward_inference(
             attention_mask = attention_mask,
             do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
         )
+        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
         hidden_states += residual
 
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
+        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.pre_feedforward_layernorm, hidden_states, out_weight)
         hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
+        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_feedforward_layernorm, hidden_states, out_weight)
         hidden_states += residual
 
         next_decoder_cache.append(present_key_value)
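Not part of the diff, but a change like this is easy to sanity-check by comparing logits from the patched inference path against stock `transformers` on the same inputs. A hedged sketch; the checkpoint id, prompt, and tolerances are illustrative only:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-9b"  # assumption: any Gemma-2 checkpoint works here
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype = torch.float16, device_map = "auto")

inputs = tokenizer("The tallest mountain on Earth is", return_tensors = "pt").to(model.device)
with torch.no_grad():
    reference = model(**inputs).logits  # stock transformers forward

# After patching the model with the fixed Gemma2Model_fast_forward_inference,
# re-run the same inputs and check agreement within fp16 noise:
# patched = model(**inputs).logits
# torch.testing.assert_close(patched, reference, rtol = 1e-2, atol = 1e-2)
```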