We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 66f35fd commit acf9771 Copy full SHA for acf9771
unsloth/models/llama.py
@@ -858,11 +858,6 @@ def _CausalLM_fast_forward(
858
logits = self.lm_head(hidden_states.to(lm_head.dtype))
859
pass
860
logits = logits.to(self.config.torch_dtype)
861
- if self.config.final_logit_softcapping is not None:
862
- logits = logits / self.config.final_logit_softcapping
863
- logits = torch.tanh(logits)
864
- logits = logits * self.config.final_logit_softcapping
865
- pass
866
867
loss = None
868
if labels is not None:
@@ -876,6 +871,7 @@ def _CausalLM_fast_forward(
876
871
loss = fast_cross_entropy_loss(
877
872
logits = shift_logits,
878
873
labels = shift_labels,
874
+ logit_softcapping = getattr(self.config, "final_logit_softcapping", 0),
879
875
)
880
881
0 commit comments