Skip to content

Commit acf9771

Browse files
committed
Update llama.py
1 parent 66f35fd commit acf9771

File tree

1 file changed

+1
-5
lines changed

1 file changed

+1
-5
lines changed

unsloth/models/llama.py

+1 −5
Original file line numberDiff line numberDiff line change
@@ -858,11 +858,6 @@ def _CausalLM_fast_forward(
858858
logits = self.lm_head(hidden_states.to(lm_head.dtype))
859859
pass
860860
logits = logits.to(self.config.torch_dtype)
861-
if self.config.final_logit_softcapping is not None:
862-
logits = logits / self.config.final_logit_softcapping
863-
logits = torch.tanh(logits)
864-
logits = logits * self.config.final_logit_softcapping
865-
pass
866861

867862
loss = None
868863
if labels is not None:
@@ -876,6 +871,7 @@ def _CausalLM_fast_forward(
876871
loss = fast_cross_entropy_loss(
877872
logits = shift_logits,
878873
labels = shift_labels,
874+
logit_softcapping = getattr(self.config, "final_logit_softcapping", 0),
879875
)
880876
pass
881877

0 commit comments

Comments (0)